From 60f0df038a334296fed256b7b10a743e7f8f3eaa Mon Sep 17 00:00:00 2001
From: frosforever
Date: Tue, 5 Dec 2017 17:15:03 -0500
Subject: [PATCH] add when column method

---
 .../functions/NonAggregateFunctions.scala     | 24 ++++++++++++
 .../NonAggregateFunctionsTests.scala          | 37 +++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
index 1bac1226..f3780990 100644
--- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
@@ -1,6 +1,7 @@
 package frameless
 package functions
 
+import org.apache.spark.sql.Column
 import org.apache.spark.sql.{functions => untyped}
 
 import scala.util.matching.Regex
@@ -104,6 +105,29 @@ trait NonAggregateFunctions {
     new TypedColumn[T, A](untyped.bitwiseNOT(column.untyped))
   }
 
+  /** Non-Aggregate function: Evaluates a list of conditions and returns one of multiple
+    * possible result expressions. If none of the conditions match, the value given to `otherwise` is returned.
+    * {{{
+    *   when(ds('boolField), ds('a))
+    *     .when(ds('otherBoolField), lit(123))
+    *     .otherwise(ds('b))
+    * }}}
+    * apache/spark
+    */
+  def when[T, A](condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] =
+    new When[T, A](condition, value)
+
+  class When[T, A] private (untypedC: Column) {
+    private[functions] def this(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]) =
+      this(untyped.when(condition.untyped, value.untyped))
+
+    def when(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] = new When[T, A](
+      untypedC.when(condition.untyped, value.untyped)
+    )
+
+    def otherwise(value: TypedColumn[T, A]): TypedColumn[T, A] =
+      new TypedColumn[T, A](untypedC.otherwise(value.untyped))(value.uencoder)
+  }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
index b93bb251..ea627a18 100644
--- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
@@ -462,6 +462,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     check(forAll(prop[Int] _))
   }
 
+  test("when") {
+    val spark = session
+    import spark.implicits._
+
+    def prop[A : TypedEncoder : Encoder](condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = {
+      val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil)
+
+      val untypedWhen = ds.toDF()
+        .select(
+          untyped.when(untyped.col("a"), untyped.col("c"))
+            .when(untyped.col("b"), untyped.col("d"))
+            .otherwise(untyped.col("e"))
+        )
+        .as[A]
+        .collect()
+        .toList
+
+      val typedWhen = ds
+        .select(
+          when(ds('a), ds('c))
+            .when(ds('b), ds('d))
+            .otherwise(ds('e))
+        )
+        .collect()
+        .run()
+        .toList
+
+      typedWhen ?= untypedWhen
+    }
+
+    check(forAll(prop[Long] _))
+    check(forAll(prop[Short] _))
+    check(forAll(prop[Byte] _))
+    check(forAll(prop[Int] _))
+    check(forAll(prop[Option[Int]] _))
+  }
+
   test("ascii") {
     val spark = session
     import spark.implicits._
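
A minimal usage sketch of the new typed `when` for anyone trying the patch locally. Nothing in it comes from the patch itself: the `Purchase` case class, the prices, and the `WhenExample` object are made up for illustration, and the imports assume the frameless layout of this period, where the non-aggregate functions are exposed through `frameless.functions.nonAggregate` and `lit` through the `frameless.functions` package object.

// Illustrative sketch only; assumes a frameless build containing this patch.
import frameless.TypedDataset
import frameless.functions._              // lit
import frameless.functions.nonAggregate._ // the new typed `when`
import org.apache.spark.sql.SparkSession

object WhenExample extends App {
  // Hypothetical row type, not part of the patch.
  case class Purchase(isMember: Boolean, hasCoupon: Boolean, listPrice: Double, memberPrice: Double)

  implicit val spark: SparkSession =
    SparkSession.builder().master("local[*]").appName("when-example").getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  val ds = TypedDataset.create(Seq(
    Purchase(isMember = true,  hasCoupon = false, listPrice = 100.0, memberPrice = 80.0),
    Purchase(isMember = false, hasCoupon = true,  listPrice = 100.0, memberPrice = 80.0),
    Purchase(isMember = false, hasCoupon = false, listPrice = 100.0, memberPrice = 80.0)
  ))

  // The first matching condition wins; `otherwise` covers the remaining rows.
  // Every branch must be a TypedColumn[Purchase, Double], so a value of the
  // wrong type or a column from another dataset is a compile error.
  val price = ds.select(
    when(ds('isMember), ds('memberPrice))
      .when(ds('hasCoupon), lit(90.0))
      .otherwise(ds('listPrice))
  )

  println(price.collect().run().toList) // expected: List(80.0, 90.0, 100.0)

  spark.stop()
}

The point of the `When` builder is visible here: the first `when` call fixes both the row type (`Purchase`) and the result type (`Double`), so each subsequent `when` branch and the final `otherwise` are checked against them at compile time, and `otherwise` can reuse the value's encoder when constructing the resulting `TypedColumn`.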