Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility asJoinColValue to allow join on Value class dataset #695

Merged
merged 1 commit into from
Mar 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] =
F.delay(dataset.count())

/** Returns `TypedColumn` of type `A` given its name.
/** Returns `TypedColumn` of type `A` given its name (alias for `col`).
*
* {{{
* tf('id)
Expand Down Expand Up @@ -250,7 +250,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def col[A](x: Function1[T, A]): TypedColumn[T, A] =
// No runtime body here: the implementation is provided at compile time by
// the macro `TypedColumnMacroImpl.applyImpl`, applied to the function `x`.
macro TypedColumnMacroImpl.applyImpl[T, A]

/** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T]
/** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`.
* {{{
* ts: TypedDataset[Foo] = ...
* ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)]
Expand All @@ -261,12 +261,28 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
case StructType(_) =>
val allColumns: Array[Column] = dataset.columns.map(dataset.col)
org.apache.spark.sql.functions.struct(allColumns.toSeq: _*)

case _ =>
dataset.col(dataset.columns.head)
}

new TypedColumn[T,T](projectedColumn)
}

/** Exposes this whole `TypedDataset[T]` as one `TypedColumn[T, T]`,
  * intended for use as an operand of a join condition.
  *
  * Only available when `T` has an `IsValueClass` instance in implicit scope.
  * The returned column wraps the underlying dataset column named `"value"`.
  *
  * {{{
  * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) =
  *   ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue)
  * }}}
  */
def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = {
  // Brings the `typedColumn` extension on Spark `Column` into scope.
  import _root_.frameless.syntax._

  val valueColumn: Column = dataset.col("value")

  valueColumn.typedColumn
}

object colMany extends SingletonProductArgs {
def applyProduct[U <: HList, Out](columns: U)
(implicit
Expand Down Expand Up @@ -635,11 +651,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
// Builds an inner join between this dataset and `other` on `condition`:
// takes both logical plans, wraps them in a disambiguated `Join` node,
// turns that into a joined plan, and materialises the result as a
// `TypedDataset[(T, U)]` using the tuple encoder.
def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
  (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = {
  import FramelessInternals._

  val lhs = logicalPlan(dataset)
  val rhs = logicalPlan(other.dataset)

  val resolvedPlan = joinPlan(
    dataset,
    disambiguate(Join(lhs, rhs, Inner, Some(condition.expr), JoinHint.NONE)),
    lhs,
    rhs
  )

  TypedDataset.create[(T, U)](
    mkDataset(dataset.sqlContext, resolvedPlan, TypedExpressionEncoder[(T, U)])
  )
}

Expand Down Expand Up @@ -1291,8 +1309,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
i7: TypedEncoder[Out]
): TypedDataset[Out] = {
val df = dataset.toDF()
val trans =
df.filter(df(column.value.name).isNotNull).as[Out](TypedExpressionEncoder[Out])
val trans = df.filter(df(column.value.name).isNotNull).
as[Out](TypedExpressionEncoder[Out])

TypedDataset.create[Out](trans)
}
}
Expand All @@ -1304,6 +1323,7 @@ object TypedDataset {
sqlContext: SparkSession
): TypedDataset[A] = {
val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])

TypedDataset.create[A](dataset)
}

Expand All @@ -1313,10 +1333,12 @@ object TypedDataset {
sqlContext: SparkSession
): TypedDataset[A] = {
val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])

TypedDataset.create[A](dataset)
}

def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = createUnsafe(dataset.toDF())
// Wraps an existing Spark `Dataset[A]`: converts it to a DataFrame and
// delegates to `createUnsafe`.
def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = {
  val untyped = dataset.toDF()

  createUnsafe(untyped)
}

/**
* Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]].
Expand Down
18 changes: 17 additions & 1 deletion dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
test("asCol with numeric operators") {
def prop(a: Seq[Long]) = {
val ds: TypedDataset[Long] = TypedDataset.create(a)
val (first,second) = (2L,5L)
val (first, second) = (2L, 5L)
val frameless: Seq[(Long, Long, Long)] =
ds.select(ds.asCol, ds.asCol+first, ds.asCol*second).collect().run()

Expand All @@ -402,6 +402,22 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
check(forAll(prop _))
}

test("reference Value class so can join on") {
  import RecordEncoderTests.{ Name, Person }

  // The one name present in both datasets; only its Person row should survive the semi join.
  val sharedName = new Name("bar")
  val expected = Person(sharedName, 23)

  val people: TypedDataset[Person] =
    TypedDataset.create(Seq(expected, Person(new Name("foo"), 11)))

  val names: TypedDataset[Name] =
    TypedDataset.create(Seq(new Name("lorem"), sharedName))

  // Left semi join: keep the people whose `name` column equals a value in `names`.
  val matched = people.joinLeftSemi(names)(people.col('name) === names.asJoinColValue)

  matched.collect().run() shouldEqual Seq(expected)
}

test("unary_!") {
val ds = TypedDataset.create((true, false) :: Nil)

Expand Down