diff --git a/build.sbt b/build.sbt
index a3dafd97..9de04ebe 100644
--- a/build.sbt
+++ b/build.sbt
@@ -273,7 +273,10 @@ lazy val datasetSettings =
       mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"),
       mc("frameless.MapGroups"),
       mc(f"frameless.MapGroups$$"),
-      dmm("frameless.functions.package.litAggr")
+      dmm("frameless.functions.package.litAggr"),
+      dmm("org.apache.spark.sql.FramelessInternals.column"),
+      dmm("frameless.TypedEncoder.collectionEncoder"),
+      dmm("frameless.TypedEncoder.setEncoder")
     )
   },
   coverageExcludedPackages := "frameless.reflection",
diff --git a/dataset/src/main/scala/frameless/CollectionCaster.scala b/dataset/src/main/scala/frameless/CollectionCaster.scala
new file mode 100644
index 00000000..bf329992
--- /dev/null
+++ b/dataset/src/main/scala/frameless/CollectionCaster.scala
@@ -0,0 +1,67 @@
+package frameless
+
+import frameless.TypedEncoder.CollectionConversion
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{
+  CodegenContext,
+  CodegenFallback,
+  ExprCode
+}
+import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
+import org.apache.spark.sql.types.{ DataType, ObjectType }
+
+case class CollectionCaster[F[_], C[_], Y](
+    child: Expression,
+    conversion: CollectionConversion[F, C, Y])
+    extends UnaryExpression
+    with CodegenFallback {
+
+  protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+
+  override def eval(input: InternalRow): Any = {
+    val o = child.eval(input).asInstanceOf[Object]
+    o match {
+      case col: F[Y] @unchecked =>
+        conversion.convert(col)
+      case _ => o
+    }
+  }
+
+  override def dataType: DataType = child.dataType
+}
+
+case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
+    extends UnaryExpression {
+
+  protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+
+  // eval is correct in interpreted mode but the codegen fallback is not: e.g.
+  // with ColumnTests.asCol and Vectors, the generated code still types the
+  // child as Vector while child.eval returns X2, hence the doGenCode below
+  override def eval(input: InternalRow): Any = {
+    val o = child.eval(input).asInstanceOf[Object]
+    o match {
+      case col: Set[Y] @unchecked =>
+        col.toSeq
+      case _ => o
+    }
+  }
+
+  def toSeqOr[T](isSet: => T, or: => T): T =
+    child.dataType match {
+      case ObjectType(cls)
+          if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+        isSet
+      case _ => or
+    }
+
+  override def dataType: DataType =
+    toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)
+
+  override protected def doGenCode(
+      ctx: CodegenContext,
+      ev: ExprCode
+    ): ExprCode =
+    defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c"))
+
+}
diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index b72ff771..1ae9512d 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -5,7 +5,6 @@ import java.util.Date
 import java.time.{ Duration, Instant, LocalDate, Period }
 import java.sql.Timestamp
 import scala.reflect.ClassTag
-import FramelessInternals.UserDefinedType
 import org.apache.spark.sql.catalyst.expressions.{
   Expression,
   UnsafeArrayData,
@@ -18,7 +17,6 @@ import org.apache.spark.sql.catalyst.util.{
 }
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-
 import shapeless._
 import shapeless.ops.hlist.IsHCons
 import com.sparkutils.shim.expressions.{
@@ -34,6 +32,8 @@ import org.apache.spark.sql.shim.{
   Invoke5 => Invoke
 }
 
+import scala.collection.immutable.{ ListSet, TreeSet }
+
 abstract class TypedEncoder[T](
     implicit
     val classTag: ClassTag[T])
@@ -509,10 +509,70 @@ object TypedEncoder {
     override def toString: String = s"arrayEncoder($jvmRepr)"
   }
 
-  implicit def collectionEncoder[C[X] <: Seq[X], T](
+  /**
+   * Per #804: when MapObjects is evaluated in interpreted mode it returns a
+   * Seq, not the derived collection type produced under codegen.
+   *
+   * This type class offers extensible conversion to the more specific types.
+   * By default Seq, List and Vector are supported for Seq-based collections,
+   * and Set, TreeSet and ListSet for Set-based collections.
+   *
+   * @tparam F the collection type MapObjects actually produces at runtime
+   * @tparam C the collection type the encoder is derived for
+   * @tparam Y the element type
+   */
+  trait CollectionConversion[F[_], C[_], Y] extends Serializable {
+    def convert(c: F[Y]): C[Y]
+  }
+
+  object CollectionConversion {
+
+    implicit def seqToSeq[Y]: CollectionConversion[Seq, Seq, Y] =
+      new CollectionConversion[Seq, Seq, Y] {
+        override def convert(c: Seq[Y]): Seq[Y] = c
+      }
+
+    implicit def seqToVector[Y]: CollectionConversion[Seq, Vector, Y] =
+      new CollectionConversion[Seq, Vector, Y] {
+        override def convert(c: Seq[Y]): Vector[Y] = c.toVector
+      }
+
+    implicit def seqToList[Y]: CollectionConversion[Seq, List, Y] =
+      new CollectionConversion[Seq, List, Y] {
+        override def convert(c: Seq[Y]): List[Y] = c.toList
+      }
+
+    implicit def setToSet[Y]: CollectionConversion[Set, Set, Y] =
+      new CollectionConversion[Set, Set, Y] {
+        override def convert(c: Set[Y]): Set[Y] = c
+      }
+
+    implicit def setToTreeSet[Y](
+        implicit
+        ordering: Ordering[Y]
+      ): CollectionConversion[Set, TreeSet, Y] =
+      new CollectionConversion[Set, TreeSet, Y] {
+
+        override def convert(c: Set[Y]): TreeSet[Y] =
+          TreeSet.newBuilder.++=(c).result()
+      }
+
+    implicit def setToListSet[Y]: CollectionConversion[Set, ListSet, Y] =
+      new CollectionConversion[Set, ListSet, Y] {
+
+        override def convert(c: Set[Y]): ListSet[Y] =
+          ListSet.newBuilder.++=(c).result()
+      }
+  }
+
+  implicit def seqEncoder[C[X] <: Seq[X], T](
+      implicit
+      i0: Lazy[RecordFieldEncoder[T]],
+      i1: ClassTag[C[T]],
+      i2: CollectionConversion[Seq, C, T]
+    ): TypedEncoder[C[T]] = collectionEncoder[Seq, C, T]
+
+  implicit def setEncoder[C[X] <: Set[X], T](
       implicit
       i0: Lazy[RecordFieldEncoder[T]],
-      i1: ClassTag[C[T]]
+      i1: ClassTag[C[T]],
+      i2: CollectionConversion[Set, C, T]
+    ): TypedEncoder[C[T]] = collectionEncoder[Set, C, T]
+
+  def collectionEncoder[O[_], C[X], T](
+      implicit
+      i0: Lazy[RecordFieldEncoder[T]],
+      i1: ClassTag[C[T]],
+      i2: CollectionConversion[O, C, T]
     ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
       private lazy val encodeT = i0.value.encoder
 
@@ -529,38 +589,31 @@ object TypedEncoder {
       if (ScalaReflection.isNativeType(enc.jvmRepr)) {
         NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
       } else {
-        MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
+        // converts to Seq; both the Set and Seq handling must convert to Seq first
+        MapObjects(
+          enc.toCatalyst,
+          SeqCaster(path),
+          enc.jvmRepr,
+          encodeT.nullable
+        )
       }
     }
 
     def fromCatalyst(path: Expression): Expression =
-      MapObjects(
-        i0.value.fromCatalyst,
-        path,
-        encodeT.catalystRepr,
-        encodeT.nullable,
-        Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
-      )
+      CollectionCaster[O, C, T](
+        MapObjects(
+          i0.value.fromCatalyst,
+          path,
+          encodeT.catalystRepr,
+          encodeT.nullable,
+          Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
+        ),
+        implicitly[CollectionConversion[O, C, T]]
+      ) // This will convert Seq to the appropriate C[_] when eval'ing.
 
     override def toString: String = s"collectionEncoder($jvmRepr)"
   }
 
-  /**
-   * @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set.
-   * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
-   * @tparam T the element type of the set.
-   * @return a `TypedEncoder` instance for `Set[T]`.
-   */
-  implicit def setEncoder[T](
-      implicit
-      i1: shapeless.Lazy[RecordFieldEncoder[T]],
-      i2: ClassTag[Set[T]]
-    ): TypedEncoder[Set[T]] = {
-    implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
-
-    TypedEncoder.usingInjection
-  }
-
   /**
    * @tparam A the key type
    * @tparam B the value type
diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala
index c47457bd..f5e5cb7a 100644
--- a/dataset/src/main/scala/frameless/functions/Udf.scala
+++ b/dataset/src/main/scala/frameless/functions/Udf.scala
@@ -2,7 +2,11 @@ package frameless
 package functions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
+import org.apache.spark.sql.catalyst.expressions.{
+  Expression,
+  LeafExpression,
+  NonSQLExpression
+}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import Block._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -67,8 +71,17 @@ trait Udf {
     ) => TypedColumn[T, R] = {
     case us =>
       val scalaUdf =
-        FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
-          s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3]))
+        FramelessUdf(
+          f,
+          us.toList[UntypedExpression[T]],
+          TypedEncoder[R],
+          s =>
+            f(
+              s.head.asInstanceOf[A1],
+              s(1).asInstanceOf[A2],
+              s(2).asInstanceOf[A3]
+            )
+        )
       new TypedColumn[T, R](scalaUdf)
   }
 
@@ -81,8 +94,18 @@ trait Udf {
   def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R):
     (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
     case us =>
       val scalaUdf =
-        FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
-          s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4]))
+        FramelessUdf(
+          f,
+          us.toList[UntypedExpression[T]],
+          TypedEncoder[R],
+          s =>
+            f(
+              s.head.asInstanceOf[A1],
+              s(1).asInstanceOf[A2],
+              s(2).asInstanceOf[A3],
+              s(3).asInstanceOf[A4]
+            )
+        )
       new TypedColumn[T, R](scalaUdf)
   }
 
@@ -95,8 +118,19 @@ trait Udf {
   def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R):
     (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
     case us =>
       val scalaUdf =
-        FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
-          s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4], s(1).asInstanceOf[A5]))
+        FramelessUdf(
+          f,
+          us.toList[UntypedExpression[T]],
+          TypedEncoder[R],
+          s =>
+            f(
+              s.head.asInstanceOf[A1],
+              s(1).asInstanceOf[A2],
+              s(2).asInstanceOf[A3],
+              s(3).asInstanceOf[A4],
+              s(4).asInstanceOf[A5]
+            )
+        )
       new TypedColumn[T, R](scalaUdf)
   }
 }
@@ -119,7 +153,8 @@ case class FramelessUdf[T, R](
 
   override def toString: String = s"FramelessUdf(${children.mkString(", ")})"
 
-  lazy val typedEnc = TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]
+  lazy val typedEnc =
+    TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]
 
   def eval(input: InternalRow): Any = {
     val jvmTypes = children.map(_.eval(input))
@@ -130,11 +165,10 @@ case class FramelessUdf[T, R](
     val retval =
       if (returnCatalyst == null)
        null
+      else if (typedEnc.isSerializedAsStructForTopLevel)
+        returnCatalyst
       else
-        if (typedEnc.isSerializedAsStructForTopLevel)
-          returnCatalyst
-        else
-          returnCatalyst.get(0, dataType)
+        returnCatalyst.get(0, dataType)
 
     retval
   }
diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala
index 4ebf5d93..ab1f3581 100644
--- a/dataset/src/test/scala/frameless/EncoderTests.scala
+++ b/dataset/src/test/scala/frameless/EncoderTests.scala
@@ -1,7 +1,6 @@
 package frameless
 
-import scala.collection.immutable.Set
-
+import scala.collection.immutable.{ ListSet, Set, TreeSet }
 import org.scalatest.matchers.should.Matchers
 
 object EncoderTests {
@@ -10,6 +9,8 @@ object EncoderTests {
   case class InstantRow(i: java.time.Instant)
   case class DurationRow(d: java.time.Duration)
   case class PeriodRow(p: java.time.Period)
+
+  case class ContainerOf[CC[X] <: Iterable[X]](a: CC[X1[Int]])
 }
 
 class EncoderTests extends TypedDatasetSuite with Matchers {
@@ -32,4 +33,55 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
   test("It should encode java.time.Period") {
     implicitly[TypedEncoder[PeriodRow]]
   }
+
+  def performCollection[C[X] <: Iterable[X]](
+      toType: Seq[X1[Int]] => C[X1[Int]]
+    )(implicit
+      ce: TypedEncoder[C[X1[Int]]]
+    ): (Unit, Unit) = evalCodeGens {
+
+    implicit val cte = TypedExpressionEncoder[C[X1[Int]]]
+    implicit val e = implicitly[TypedEncoder[ContainerOf[C]]]
+    implicit val te = TypedExpressionEncoder[ContainerOf[C]]
+    implicit val xe = implicitly[TypedEncoder[X1[ContainerOf[C]]]]
+    implicit val xte = TypedExpressionEncoder[X1[ContainerOf[C]]]
+    val v = toType((1 to 20).map(X1(_)))
+    val ds = {
+      sqlContext.createDataset(Seq(X1[ContainerOf[C]](ContainerOf[C](v))))
+    }
+    ds.head.a.a shouldBe v
+    ()
+  }
+
+  test("It should serde a Seq of Objects") {
+    performCollection[Seq](_.toSeq)
+  }
+
+  test("It should serde a Set of Objects") {
+    performCollection[Set](_.toSet)
+  }
+
+  test("It should serde a Vector of Objects") {
+    performCollection[Vector](_.toVector)
+  }
+
+  test("It should serde a TreeSet of Objects") {
+    // an explicit Ordering is only needed for 2.12
+    implicit val ordering = new Ordering[X1[Int]] {
+      val intordering = implicitly[Ordering[Int]]
+
+      override def compare(x: X1[Int], y: X1[Int]): Int =
+        intordering.compare(x.a, y.a)
+    }
+
+    performCollection[TreeSet](TreeSet.newBuilder.++=(_).result())
+  }
+
+  test("It should serde a List of Objects") {
+    performCollection[List](_.toList)
+  }
+
+  test("It should serde a ListSet of Objects") {
+    performCollection[ListSet](ListSet.newBuilder.++=(_).result())
+  }
 }
diff --git a/dataset/src/test/scala/frameless/functions/UdfTests.scala b/dataset/src/test/scala/frameless/functions/UdfTests.scala
index fdade538..a193ddc6 100644
--- a/dataset/src/test/scala/frameless/functions/UdfTests.scala
+++ b/dataset/src/test/scala/frameless/functions/UdfTests.scala
@@ -7,7 +7,7 @@ import org.scalacheck.Prop._
 
 class UdfTests extends TypedDatasetSuite {
 
-/*
+  /*
   implicit def vectorArbitrary[A: Arbitrary]: Arbitrary[Vector[A]] =
     Arbitrary(
       for {
@@ -17,17 +17,25 @@ class UdfTests extends TypedDatasetSuite {
     } yield vector
   )
-*/
+   */
 
   test("one argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X1[A]], f1: A => B): Prop = {
+    def prop[A: TypedEncoder, B: TypedEncoder](
+        data: Vector[X1[A]],
+        f1: A => B
+      ): Prop = {
       val dataset: TypedDataset[X1[A]] = TypedDataset.create(data)
       val u1 = udf[X1[A], A, B](f1)
       val u2 = dataset.makeUDF(f1)
       val A = dataset.col[A]('a)
 
       // filter forces whole codegen
-      val codegen = dataset.deserialized.filter((_:X1[A]) => true).select(u1(A)).collect().run().toVector
+      val codegen = dataset.deserialized
+        .filter((_: X1[A]) => true)
+        .select(u1(A))
+        .collect()
+        .run()
+        .toVector
 
       // otherwise it uses local relation
       val local = dataset.select(u2(A)).collect().run().toVector
@@ -49,15 +57,22 @@ class UdfTests extends TypedDatasetSuite {
     // Vector isn't supported by MapObjects, not all collections are equal
     check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _))
 
-    def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = prop(Vector(X1(a)), f)
+    def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop =
+      prop(Vector(X1(a)), f)
 
-    check(forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _))
+    check(
+      forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _)
+    )
     check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _))
   }
 
   test("multiple one argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
-      (data: Vector[X3[A, B, C]], f1: A => A, f2: B => B, f3: C => C): Prop = {
+    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+        data: Vector[X3[A, B, C]],
+        f1: A => A,
+        f2: B => B,
+        f3: C => C
+      ): Prop = {
       val dataset = TypedDataset.create(data)
       val u11 = udf[X3[A, B, C], A, A](f1)
       val u21 = udf[X3[A, B, C], B, B](f2)
@@ -69,8 +84,10 @@ class UdfTests extends TypedDatasetSuite {
       val B = dataset.col[B]('b)
       val C = dataset.col[C]('c)
 
-      val dataset21 = dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector
-      val dataset22 = dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector
+      val dataset21 =
+        dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector
+      val dataset22 =
+        dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector
       val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c)))
 
       (dataset21 ?= d) && (dataset22 ?= d)
@@ -83,8 +100,10 @@ class UdfTests extends TypedDatasetSuite {
   }
 
   test("two argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
-      (data: Vector[X3[A, B, C]], f1: (A, B) => C): Prop = {
+    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+        data: Vector[X3[A, B, C]],
+        f1: (A, B) => C
+      ): Prop = {
       val dataset = TypedDataset.create(data)
       val u1 = udf[X3[A, B, C], A, B, C](f1)
       val u2 = dataset.makeUDF(f1)
@@ -103,8 +122,11 @@ class UdfTests extends TypedDatasetSuite {
   }
 
   test("multiple two argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
-      (data: Vector[X3[A, B, C]], f1: (A, B) => C, f2: (B, C) => A): Prop = {
+    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+        data: Vector[X3[A, B, C]],
+        f1: (A, B) => C,
+        f2: (B, C) => A
+      ): Prop = {
       val dataset = TypedDataset.create(data)
       val u11 = udf[X3[A, B, C], A, B, C](f1)
       val u12 = dataset.makeUDF(f1)
@@ -115,8 +137,10 @@ class UdfTests extends TypedDatasetSuite {
       val B = dataset.col[B]('b)
       val C = dataset.col[C]('c)
 
-      val dataset21 = dataset.select(u11(A, B), u21(B, C)).collect().run().toVector
-      val dataset22 = dataset.select(u12(A, B), u22(B, C)).collect().run().toVector
+      val dataset21 =
+        dataset.select(u11(A, B), u21(B, C)).collect().run().toVector
+      val dataset22 =
+        dataset.select(u12(A, B), u22(B, C)).collect().run().toVector
       val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c)))
 
       (dataset21 ?= d) && (dataset22 ?= d)
@@ -127,73 +151,96 @@ class UdfTests extends TypedDatasetSuite {
   }
 
   test("three argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
-      (data: Vector[X3[A, B, C]], f: (A, B, C) => C): Prop = {
-      val dataset = TypedDataset.create(data)
-      val u1 = udf[X3[A, B, C], A, B, C, C](f)
-      val u2 = dataset.makeUDF(f)
-
-      val A = dataset.col[A]('a)
-      val B = dataset.col[B]('b)
-      val C = dataset.col[C]('c)
-
-      val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector
-      val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector
-      val d = data.map(x => f(x.a, x.b, x.c))
-
-      (dataset21 ?= d) && (dataset22 ?= d)
+    forceInterpreted {
+      def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+          data: Vector[X3[A, B, C]],
+          f: (A, B, C) => C
+        ): Prop = {
+        val dataset = TypedDataset.create(data)
+        val u1 = udf[X3[A, B, C], A, B, C, C](f)
+        val u2 = dataset.makeUDF(f)
+
+        val A = dataset.col[A]('a)
+        val B = dataset.col[B]('b)
+        val C = dataset.col[C]('c)
+
+        val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector
+        val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector
+        val d = data.map(x => f(x.a, x.b, x.c))
+
+        (dataset21 ?= d) && (dataset22 ?= d)
+      }
+
+      check(forAll(prop[Int, Int, Int] _))
+      check(forAll(prop[String, Int, Int] _))
     }
-
-    check(forAll(prop[Int, Int, Int] _))
-    check(forAll(prop[String, Int, Int] _))
   }
 
   test("four argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder]
-      (data: Vector[X4[A, B, C, D]], f: (A, B, C, D) => C): Prop = {
-      val dataset = TypedDataset.create(data)
-      val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f)
-      val u2 = dataset.makeUDF(f)
-
-      val A = dataset.col[A]('a)
-      val B = dataset.col[B]('b)
-      val C = dataset.col[C]('c)
-      val D = dataset.col[D]('d)
-
-      val dataset21 = dataset.select(u1(A, B, C, D)).collect().run().toVector
-      val dataset22 = dataset.select(u2(A, B, C, D)).collect().run().toVector
-      val d = data.map(x => f(x.a, x.b, x.c, x.d))
-
-      (dataset21 ?= d) && (dataset22 ?= d)
+    forceInterpreted {
+      def prop[
+          A: TypedEncoder,
+          B: TypedEncoder,
+          C: TypedEncoder,
+          D: TypedEncoder
+        ](data: Vector[X4[A, B, C, D]],
+          f: (A, B, C, D) => C
+        ): Prop = {
+        val dataset = TypedDataset.create(data)
+        val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f)
+        val u2 = dataset.makeUDF(f)
+
+        val A = dataset.col[A]('a)
+        val B = dataset.col[B]('b)
+        val C = dataset.col[C]('c)
+        val D = dataset.col[D]('d)
+
+        val dataset21 = dataset.select(u1(A, B, C, D)).collect().run().toVector
+        val dataset22 = dataset.select(u2(A, B, C, D)).collect().run().toVector
+        val d = data.map(x => f(x.a, x.b, x.c, x.d))
+
+        (dataset21 ?= d) && (dataset22 ?= d)
+      }
+
+      check(forAll(prop[Int, Int, Int, Int] _))
+      check(forAll(prop[String, Int, Int, String] _))
+      check(forAll(prop[String, String, String, String] _))
+      check(forAll(prop[String, Long, String, String] _))
+      check(forAll(prop[String, Boolean, Boolean, String] _))
     }
-
-    check(forAll(prop[Int, Int, Int, Int] _))
-    check(forAll(prop[String, Int, Int, String] _))
-    check(forAll(prop[String, String, String, String] _))
-    check(forAll(prop[String, Long, String, String] _))
-    check(forAll(prop[String, Boolean, Boolean, String] _))
   }
 
   test("five argument udf") {
-    def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder, E: TypedEncoder]
-      (data: Vector[X5[A, B, C, D, E]], f: (A, B, C, D, E) => C): Prop = {
-      val dataset = TypedDataset.create(data)
-      val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f)
-      val u2 = dataset.makeUDF(f)
-
-      val A = dataset.col[A]('a)
-      val B = dataset.col[B]('b)
-      val C = dataset.col[C]('c)
-      val D = dataset.col[D]('d)
-      val E = dataset.col[E]('e)
-
-      val dataset21 = dataset.select(u1(A, B, C, D, E)).collect().run().toVector
-      val dataset22 = dataset.select(u2(A, B, C, D, E)).collect().run().toVector
-      val d = data.map(x => f(x.a, x.b, x.c, x.d, x.e))
-
-      (dataset21 ?= d) && (dataset22 ?= d)
+    forceInterpreted {
+      def prop[
+          A: TypedEncoder,
+          B: TypedEncoder,
+          C: TypedEncoder,
+          D: TypedEncoder,
+          E: TypedEncoder
+        ](data: Vector[X5[A, B, C, D, E]],
+          f: (A, B, C, D, E) => C
+        ): Prop = {
+        val dataset = TypedDataset.create(data)
+        val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f)
+        val u2 = dataset.makeUDF(f)
+
+        val A = dataset.col[A]('a)
+        val B = dataset.col[B]('b)
+        val C = dataset.col[C]('c)
+        val D = dataset.col[D]('d)
+        val E = dataset.col[E]('e)
+
+        val dataset21 =
+          dataset.select(u1(A, B, C, D, E)).collect().run().toVector
+        val dataset22 =
+          dataset.select(u2(A, B, C, D, E)).collect().run().toVector
+        val d = data.map(x => f(x.a, x.b, x.c, x.d, x.e))
+
+        (dataset21 ?= d) && (dataset22 ?= d)
+      }
+
+      check(forAll(prop[Int, Int, Int, Int, Int] _))
     }
-
-    check(forAll(prop[Int, Int, Int, Int, Int] _))
   }
 }
diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala
index 813a9666..c5654d7c 100644
--- a/dataset/src/test/scala/frameless/package.scala
+++ b/dataset/src/test/scala/frameless/package.scala
@@ -1,6 +1,9 @@
 import java.time.format.DateTimeFormatter
 import java.time.{ LocalDateTime => JavaLocalDateTime }
 
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.internal.SQLConf
+
 import org.scalacheck.{ Arbitrary, Gen }
 
 package object frameless {
@@ -39,6 +42,14 @@ package object frameless {
 
   def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary
 
+  implicit def arbSeq[A](
+      implicit
+      A: Arbitrary[A]
+    ): Arbitrary[scala.collection.Seq[A]] =
+    Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector.toSeq))
+
+  def seqGen[A: Arbitrary]: Gen[scala.collection.Seq[A]] = arbSeq[A].arbitrary
+
   implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary {
     for {
       int <- Arbitrary.arbitrary[Int]
@@ -161,4 +172,58 @@ package object frameless {
     }
     res
   }
+
+  // adapted from Quality, which in turn took these helpers from Spark's test code
+
+  // if codegen blows up, debug in CodeGenerator around lines 1294 and 1299 and capture code.body
+  def forceCodeGen[T](f: => T): T = {
+    val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
+
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+      f
+    }
+  }
+
+  def forceInterpreted[T](f: => T): T = {
+    val codegenMode = CodegenObjectFactoryMode.NO_CODEGEN.toString
+
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+      f
+    }
+  }
+
+  /**
+   * Runs the same test body twice, first interpreted and then with codegen
+   * forced, returning both results.
+   */
+  def evalCodeGens[T](f: => T): (T, T) =
+    (forceInterpreted(f), forceCodeGen(f))
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) => conf.setConfString(k, v) }
+    try f
+    finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+}
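
Note (editor's sketch, not part of the patch): with the instances above living in TypedEncoder's companion object, the following encoders are expected to resolve without extra imports of implicits; only TreeSet additionally needs an element Ordering, mirroring setToTreeSet:

  import scala.collection.immutable.{ ListSet, TreeSet }
  import frameless.TypedEncoder

  implicitly[TypedEncoder[Seq[Int]]]     // seqEncoder via seqToSeq
  implicitly[TypedEncoder[Vector[Int]]]  // seqEncoder via seqToVector
  implicitly[TypedEncoder[List[Int]]]    // seqEncoder via seqToList
  implicitly[TypedEncoder[Set[Int]]]     // setEncoder via setToSet
  implicitly[TypedEncoder[TreeSet[Int]]] // setEncoder via setToTreeSet, needs Ordering[Int]
  implicitly[TypedEncoder[ListSet[Int]]] // setEncoder via setToListSet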
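
Because CollectionConversion is an open type class, a collection outside the defaults can in principle be supported by supplying one extra instance. A minimal sketch under that assumption, using scala.collection.immutable.Queue as a hypothetical example (Queue is a Seq, so seqEncoder should pick the conversion up):

  import scala.collection.immutable.Queue
  import frameless.TypedEncoder
  import frameless.TypedEncoder.CollectionConversion

  // Hypothetical extension: rebuild the Seq produced by MapObjects as a Queue.
  implicit def seqToQueue[Y]: CollectionConversion[Seq, Queue, Y] =
    new CollectionConversion[Seq, Queue, Y] {
      override def convert(c: Seq[Y]): Queue[Y] = Queue(c: _*)
    }

  // With the conversion in scope, the encoder should resolve as usual:
  implicitly[TypedEncoder[Queue[Int]]]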
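
The Udf.scala hunks also correct argument indexing in the 3-, 4- and 5-ary wrappers, where every argument past the second previously read slot 1 of the decoded values. An illustration of the corrected positional mapping (values hypothetical):

  // s holds the already-decoded argument values, in column order
  val s = Seq[Any](1, "two", 3.0)
  val a1 = s.head.asInstanceOf[Int]  // slot 0
  val a2 = s(1).asInstanceOf[String] // slot 1
  val a3 = s(2).asInstanceOf[Double] // slot 2 (previously s(1), duplicating a2)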