Skip to content

Commit

Permalink
Add a typed col function for creating column references
Browse files Browse the repository at this point in the history
Resolves #186.
  • Loading branch information
Itamar Ravid committed Sep 22, 2017
1 parent 68aa838 commit 897e499
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
10 changes: 10 additions & 0 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package frameless

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions.{ col => sparkCol }
import shapeless.Witness

package object functions extends Udf with UnaryFunctions {
object aggregate extends AggregateFunctions
Expand All @@ -17,4 +19,12 @@ package object functions extends Udf with UnaryFunctions {
new TypedColumn(expr)
}
}

def col[T, A](column: Witness.Lt[Symbol])(
implicit
exists: TypedColumn.Exists[T, column.T, A],
encoder: TypedEncoder[A]): TypedColumn[T, A] = {
val untypedExpr = sparkCol(column.value.name).as[A](TypedExpressionEncoder[A])
new TypedColumn[T, A](untypedExpr)
}
}
3 changes: 2 additions & 1 deletion dataset/src/test/scala/frameless/SelectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ class SelectTests extends TypedDatasetSuite {
val A = dataset.col[A]('a)

val dataset2 = dataset.select(A).collect().run().toVector
val symDataset2 = dataset.select(functions.col('a)).collect().run().toVector
val data2 = data.map { case X4(a, _, _, _) => a }

dataset2 ?= data2
(dataset2 ?= data2) && (symDataset2 ?= data2)
}

check(forAll(prop[Int, Int, Int, Int] _))
Expand Down

0 comments on commit 897e499

Please sign in to comment.