diff --git a/build.sbt b/build.sbt index a3dafd97..ded5f1e5 100644 --- a/build.sbt +++ b/build.sbt @@ -273,7 +273,15 @@ lazy val datasetSettings = mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"), mc("frameless.MapGroups"), mc(f"frameless.MapGroups$$"), - dmm("frameless.functions.package.litAggr") + dmm("frameless.functions.package.litAggr"), + dmm("org.apache.spark.sql.FramelessInternals.column"), + dmm("frameless.TypedEncoder.collectionEncoder"), + dmm("frameless.TypedEncoder.setEncoder"), + dmm("frameless.functions.FramelessUdf.evalCode"), + dmm("frameless.functions.FramelessUdf.copy"), + dmm("frameless.functions.FramelessUdf.this"), + dmm("frameless.functions.FramelessUdf.apply"), + imt("frameless.functions.FramelessUdf.apply") ) }, coverageExcludedPackages := "frameless.reflection", diff --git a/dataset/src/main/scala/frameless/CollectionCaster.scala b/dataset/src/main/scala/frameless/CollectionCaster.scala new file mode 100644 index 00000000..bf329992 --- /dev/null +++ b/dataset/src/main/scala/frameless/CollectionCaster.scala @@ -0,0 +1,67 @@ +package frameless + +import frameless.TypedEncoder.CollectionConversion +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenContext, + CodegenFallback, + ExprCode +} +import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression } +import org.apache.spark.sql.types.{ DataType, ObjectType } + +case class CollectionCaster[F[_], C[_], Y]( + child: Expression, + conversion: CollectionConversion[F, C, Y]) + extends UnaryExpression + with CodegenFallback { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: F[Y] @unchecked => + conversion.convert(col) + case _ => o + } + } + + override def dataType: DataType = child.dataType +} + +case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) + extends UnaryExpression { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: Set[Y] @unchecked => + col.toSeq + case _ => o + } + } + + def toSeqOr[T](isSet: => T, or: => T): T = + child.dataType match { + case ObjectType(cls) + if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + isSet + case t => or + } + + override def dataType: DataType = + toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType) + + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = + defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c")) + +} diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index b72ff771..e11ec73d 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -11,6 +11,10 @@ import org.apache.spark.sql.catalyst.expressions.{ UnsafeArrayData, Literal } +import org.apache.spark.sql.FramelessInternals +import org.apache.spark.sql.FramelessInternals.UserDefinedType +import org.apache.spark.sql.{ reflection => ScalaReflection } + import org.apache.spark.sql.catalyst.util.{ ArrayBasedMapData, DateTimeUtils, @@ -18,7 +22,6 @@ import org.apache.spark.sql.catalyst.util.{ } import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import shapeless._ import shapeless.ops.hlist.IsHCons import com.sparkutils.shim.expressions.{ @@ -34,6 +37,8 @@ import org.apache.spark.sql.shim.{ Invoke5 => Invoke } +import scala.collection.immutable.{ ListSet, TreeSet } + abstract class TypedEncoder[T]( implicit val classTag: ClassTag[T]) @@ -509,10 +514,70 @@ object TypedEncoder { override def toString: String = s"arrayEncoder($jvmRepr)" } - implicit def collectionEncoder[C[X] <: Seq[X], T]( + /** + * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation + * + * This type class offers extensible conversion for more specific types. By default Seq, List and Vector for Seq's and Set, TreeSet and ListSet are supported. + * + * @tparam C + */ + trait CollectionConversion[F[_], C[_], Y] extends Serializable { + def convert(c: F[Y]): C[Y] + } + + object CollectionConversion { + + implicit def seqToSeq[Y] = new CollectionConversion[Seq, Seq, Y] { + override def convert(c: Seq[Y]): Seq[Y] = c + } + + implicit def seqToVector[Y] = new CollectionConversion[Seq, Vector, Y] { + override def convert(c: Seq[Y]): Vector[Y] = c.toVector + } + + implicit def seqToList[Y] = new CollectionConversion[Seq, List, Y] { + override def convert(c: Seq[Y]): List[Y] = c.toList + } + + implicit def setToSet[Y] = new CollectionConversion[Set, Set, Y] { + override def convert(c: Set[Y]): Set[Y] = c + } + + implicit def setToTreeSet[Y]( + implicit + ordering: Ordering[Y] + ) = new CollectionConversion[Set, TreeSet, Y] { + + override def convert(c: Set[Y]): TreeSet[Y] = + TreeSet.newBuilder.++=(c).result() + } + + implicit def setToListSet[Y] = new CollectionConversion[Set, ListSet, Y] { + + override def convert(c: Set[Y]): ListSet[Y] = + ListSet.newBuilder.++=(c).result() + } + } + + implicit def seqEncoder[C[X] <: Seq[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Seq, C, T] + ) = collectionEncoder[Seq, C, T] + + implicit def setEncoder[C[X] <: Set[X], T]( implicit i0: Lazy[RecordFieldEncoder[T]], - i1: ClassTag[C[T]] + i1: ClassTag[C[T]], + i2: CollectionConversion[Set, C, T] + ) = collectionEncoder[Set, C, T] + + def collectionEncoder[O[_], C[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[O, C, T] ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] { private lazy val encodeT = i0.value.encoder @@ -529,38 +594,31 @@ object TypedEncoder { if (ScalaReflection.isNativeType(enc.jvmRepr)) { NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr) } else { - MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable) + // converts to Seq, both Set and Seq handling must convert to Seq first + MapObjects( + enc.toCatalyst, + SeqCaster(path), + enc.jvmRepr, + encodeT.nullable + ) } } def fromCatalyst(path: Expression): Expression = - MapObjects( - i0.value.fromCatalyst, - path, - encodeT.catalystRepr, - encodeT.nullable, - Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly - ) + CollectionCaster[O, C, T]( + MapObjects( + i0.value.fromCatalyst, + path, + encodeT.catalystRepr, + encodeT.nullable, + Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling + ), + implicitly[CollectionConversion[O, C, T]] + ) // This will convert Seq to the appropriate C[_] when eval'ing. override def toString: String = s"collectionEncoder($jvmRepr)" } - /** - * @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set. - * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type. - * @tparam T the element type of the set. - * @return a `TypedEncoder` instance for `Set[T]`. - */ - implicit def setEncoder[T]( - implicit - i1: shapeless.Lazy[RecordFieldEncoder[T]], - i2: ClassTag[Set[T]] - ): TypedEncoder[Set[T]] = { - implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet) - - TypedEncoder.usingInjection - } - /** * @tparam A the key type * @tparam B the value type diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index cebfe4d2..c34e8561 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -49,16 +49,12 @@ trait Udf { ) => TypedColumn[T, R] = { case us => val scalaUdf = -<<<<<<< HEAD - FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) -======= FramelessUdf( f, us.toList[UntypedExpression[T]], TypedEncoder[R], s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2]) ) ->>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start) new TypedColumn[T, R](scalaUdf) } @@ -75,9 +71,6 @@ trait Udf { ) => TypedColumn[T, R] = { case us => val scalaUdf = -<<<<<<< HEAD - FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) -======= FramelessUdf( f, us.toList[UntypedExpression[T]], @@ -89,7 +82,6 @@ trait Udf { s(2).asInstanceOf[A3] ) ) ->>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start) new TypedColumn[T, R](scalaUdf) } @@ -102,9 +94,6 @@ trait Udf { def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { case us => val scalaUdf = -<<<<<<< HEAD - FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) -======= FramelessUdf( f, us.toList[UntypedExpression[T]], @@ -117,7 +106,6 @@ trait Udf { s(3).asInstanceOf[A4] ) ) ->>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start) new TypedColumn[T, R](scalaUdf) } @@ -130,9 +118,6 @@ trait Udf { def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { case us => val scalaUdf = -<<<<<<< HEAD - FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) -======= FramelessUdf( f, us.toList[UntypedExpression[T]], @@ -146,7 +131,6 @@ trait Udf { s(4).asInstanceOf[A5] ) ) ->>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start) new TypedColumn[T, R](scalaUdf) } } diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala index 4ebf5d93..ab1f3581 100644 --- a/dataset/src/test/scala/frameless/EncoderTests.scala +++ b/dataset/src/test/scala/frameless/EncoderTests.scala @@ -1,7 +1,6 @@ package frameless -import scala.collection.immutable.Set - +import scala.collection.immutable.{ ListSet, Set, TreeSet } import org.scalatest.matchers.should.Matchers object EncoderTests { @@ -10,6 +9,8 @@ object EncoderTests { case class InstantRow(i: java.time.Instant) case class DurationRow(d: java.time.Duration) case class PeriodRow(p: java.time.Period) + + case class ContainerOf[CC[X] <: Iterable[X]](a: CC[X1[Int]]) } class EncoderTests extends TypedDatasetSuite with Matchers { @@ -32,4 +33,55 @@ class EncoderTests extends TypedDatasetSuite with Matchers { test("It should encode java.time.Period") { implicitly[TypedEncoder[PeriodRow]] } + + def performCollection[C[X] <: Iterable[X]]( + toType: Seq[X1[Int]] => C[X1[Int]] + )(implicit + ce: TypedEncoder[C[X1[Int]]] + ): (Unit, Unit) = evalCodeGens { + + implicit val cte = TypedExpressionEncoder[C[X1[Int]]] + implicit val e = implicitly[TypedEncoder[ContainerOf[C]]] + implicit val te = TypedExpressionEncoder[ContainerOf[C]] + implicit val xe = implicitly[TypedEncoder[X1[ContainerOf[C]]]] + implicit val xte = TypedExpressionEncoder[X1[ContainerOf[C]]] + val v = toType((1 to 20).map(X1(_))) + val ds = { + sqlContext.createDataset(Seq(X1[ContainerOf[C]](ContainerOf[C](v)))) + } + ds.head.a.a shouldBe v + () + } + + test("It should serde a Seq of Objects") { + performCollection[Seq](_) + } + + test("It should serde a Set of Objects") { + performCollection[Set](_) + } + + test("It should serde a Vector of Objects") { + performCollection[Vector](_.toVector) + } + + test("It should serde a TreeSet of Objects") { + // only needed for 2.12 + implicit val ordering = new Ordering[X1[Int]] { + val intordering = implicitly[Ordering[Int]] + + override def compare(x: X1[Int], y: X1[Int]): Int = + intordering.compare(x.a, y.a) + } + + performCollection[TreeSet](TreeSet.newBuilder.++=(_).result()) + } + + test("It should serde a List of Objects") { + performCollection[List](_.toList) + } + + test("It should serde a ListSet of Objects") { + performCollection[ListSet](ListSet.newBuilder.++=(_).result()) + } } diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala index be7afb00..601613c8 100644 --- a/dataset/src/test/scala/frameless/package.scala +++ b/dataset/src/test/scala/frameless/package.scala @@ -212,4 +212,58 @@ package object frameless { } res } + + // from Quality, which is from Spark test versions + + // if this blows then debug on CodeGenerator 1294, 1299 and grab code.body + def forceCodeGen[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + def forceInterpreted[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + /** + * runs the same test with both eval and codegen, then does the same again using resolveWith + * + * @param f + * @tparam T + * @return + */ + def evalCodeGens[T](f: => T): (T, T) = + (forceInterpreted(f), forceCodeGen(f)) + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => conf.setConfString(k, v) } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + }