
Commit

Merge pull request #220 from frosforever/i#164-when
add when column method
OlivierBlanvillain committed Dec 8, 2017
2 parents 75fc3dc + 60f0df0 commit 73eada3
Showing 2 changed files with 61 additions and 0 deletions.
@@ -1,6 +1,7 @@
package frameless
package functions

import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => untyped}

import scala.util.matching.Regex
@@ -104,6 +105,29 @@ trait NonAggregateFunctions {
    new TypedColumn[T, A](untyped.bitwiseNOT(column.untyped))
  }

  /** Non-Aggregate function: Evaluates a list of conditions and returns one of multiple
    * possible result expressions. If no condition matches, the `otherwise` value is returned.
    * {{{
    *   when(ds('boolField), ds('a))
    *     .when(ds('otherBoolField), lit(123))
    *     .otherwise(ds('b))
    * }}}
    * apache/spark
    */
  def when[T, A](condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] =
    new When[T, A](condition, value)

  class When[T, A] private (untypedC: Column) {
    private[functions] def this(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]) =
      this(untyped.when(condition.untyped, value.untyped))

    def when(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] = new When[T, A](
      untypedC.when(condition.untyped, value.untyped)
    )

    def otherwise(value: TypedColumn[T, A]): TypedColumn[T, A] =
      new TypedColumn[T, A](untypedC.otherwise(value.untyped))(value.uencoder)
  }

  //////////////////////////////////////////////////////////////////////////////////////////////
  // String functions
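A minimal usage sketch of the new API (not part of this commit): the Event case class, its column names, and the nonAggregate import path are illustrative assumptions, chosen to mirror the scaladoc example above.

import frameless.TypedDataset
import frameless.functions.nonAggregate._   // assumed location of the typed `when`

// Hypothetical record type used only for this sketch.
case class Event(isError: Boolean, isWarning: Boolean, errorCode: Int, warningCode: Int, okCode: Int)

// Conditions are evaluated in order; the first match wins and `otherwise` supplies the fallback.
def severityCode(events: TypedDataset[Event]): TypedDataset[Int] =
  events.select(
    when(events('isError), events('errorCode))
      .when(events('isWarning), events('warningCode))
      .otherwise(events('okCode))
  )

Because every branch is a TypedColumn[Event, Int], supplying a value column of a different type, or omitting `otherwise`, fails to compile instead of surfacing as a runtime analysis error.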
@@ -462,6 +462,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
    check(forAll(prop[Int] _))
  }

test("when") {
val spark = session
import spark.implicits._

def prop[A : TypedEncoder : Encoder](condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = {
val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil)

val untypedWhen = ds.toDF()
.select(
untyped.when(untyped.col("a"), untyped.col("c"))
.when(untyped.col("b"), untyped.col("d"))
.otherwise(untyped.col("e"))
)
.as[A]
.collect()
.toList

val typedWhen = ds
.select(
when(ds('a), ds('c))
.when(ds('b), ds('d))
.otherwise(ds('e))
)
.collect()
.run()
.toList

typedWhen ?= untypedWhen
}

check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
check(forAll(prop[Byte] _))
check(forAll(prop[Int] _))
check(forAll(prop[Option[Int]] _))
}

test("ascii") {
val spark = session
import spark.implicits._
