diff --git a/dataset/src/main/scala/frameless/CollectionCaster.scala b/dataset/src/main/scala/frameless/CollectionCaster.scala index 8eb712a2..55e7ca7d 100644 --- a/dataset/src/main/scala/frameless/CollectionCaster.scala +++ b/dataset/src/main/scala/frameless/CollectionCaster.scala @@ -1,24 +1,52 @@ package frameless -import frameless.TypedEncoder.SeqConversion +import frameless.TypedEncoder.CollectionConversion import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, ObjectType} -case class CollectionCaster[C[_]](child: Expression, conversion: SeqConversion[C]) extends UnaryExpression with CodegenFallback { +case class CollectionCaster[F[_],C[_],Y](child: Expression, conversion: CollectionConversion[F,C,Y]) extends UnaryExpression with CodegenFallback { protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) override def eval(input: InternalRow): Any = { val o = child.eval(input).asInstanceOf[Object] o match { - case seq: scala.collection.Seq[_] => - conversion.convertSeq(seq) - case set: scala.collection.Set[_] => - o + case col: F[Y] @unchecked => + conversion.convert(col) case _ => o } } override def dataType: DataType = child.dataType } + +case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) extends UnaryExpression { + protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) + + // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: Set[Y]@unchecked => + col.toSeq + case _ => o + } + } + + def toSeqOr[T](isSet: => T, or: => T): T = + child.dataType match { + case ObjectType(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + isSet + case t => or + } + + override def dataType: DataType = + toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen(ctx, ev, c => + toSeqOr(s"$c.toSeq()", s"$c") + ) + +} \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index ee762903..f6f9b2e0 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -1,31 +1,24 @@ package frameless import java.math.BigInteger - import java.util.Date - -import java.time.{ Duration, Instant, Period, LocalDate } - +import java.time.{Duration, Instant, LocalDate, Period} import java.sql.Timestamp - import scala.reflect.ClassTag - import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.FramelessInternals.UserDefinedType -import org.apache.spark.sql.{ reflection => ScalaReflection } +import org.apache.spark.sql.{reflection => ScalaReflection} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ - ArrayBasedMapData, - DateTimeUtils, - GenericArrayData -} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import shapeless._ import shapeless.ops.hlist.IsHCons +import scala.collection.generic.CanBuildFrom +import scala.collection.immutable.TreeSet + abstract class TypedEncoder[T]( implicit val classTag: ClassTag[T]) @@ -501,27 +494,57 @@ object TypedEncoder { override def toString: String = s"arrayEncoder($jvmRepr)" } - trait SeqConversion[C[_]] extends Serializable { - def convertSeq[Y](c: Seq[Y]): C[Y] + /** + * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation + * + * This type class offers extensible conversion for more specific types. By default Seq, List and Vector are supported. + * + * @tparam C + */ + trait CollectionConversion[F[_], C[_], Y] extends Serializable { + def convert(c: F[Y]): C[Y] } - object SeqConversion { - implicit val seqToSeq = new SeqConversion[Seq] { - override def convertSeq[Y](c: Seq[Y]): Seq[Y] = c + object CollectionConversion { + implicit def seqToSeq[Y](implicit cbf: CanBuildFrom[Nothing, Y, Seq[Y]]) = new CollectionConversion[Seq, Seq, Y] { + override def convert(c: Seq[Y]): Seq[Y] = c + } + implicit def seqToVector[Y](implicit cbf: CanBuildFrom[Nothing, Y, Vector[Y]]) = new CollectionConversion[Seq, Vector, Y] { + override def convert(c: Seq[Y]): Vector[Y] = c.toVector + } + implicit def seqToList[Y](implicit cbf: CanBuildFrom[Nothing, Y, List[Y]]) = new CollectionConversion[Seq, List, Y] { + override def convert(c: Seq[Y]): List[Y] = c.toList } - implicit val seqToVector = new SeqConversion[Vector] { - override def convertSeq[Y](c: Seq[Y]): Vector[Y] = c.toVector + implicit def setToSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, Set[Y]]) = new CollectionConversion[Set, Set, Y] { + override def convert(c: Set[Y]): Set[Y] = c } - implicit val seqToList = new SeqConversion[List] { - override def convertSeq[Y](c: Seq[Y]): List[Y] = c.toList + implicit def setToTreeSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, TreeSet[Y]]) = new CollectionConversion[Set, TreeSet, Y] { + override def convert(c: Set[Y]): TreeSet[Y] = c.to[TreeSet] } } - implicit def collectionEncoder[C[X] <: Seq[X], T]( + implicit def seqEncoder[C[X] <: Seq[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Seq, C, T], + i3: CanBuildFrom[Nothing, T, C[T]] + ) = collectionEncoder[Seq, C, T] + + implicit def setEncoder[C[X] <: Set[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Set, C, T], + i3: CanBuildFrom[Nothing, T, C[T]] + ) = collectionEncoder[Set, C, T] + + def collectionEncoder[O[_], C[X], T]( implicit i0: Lazy[RecordFieldEncoder[T]], i1: ClassTag[C[T]], - i2: SeqConversion[C] + i2: CollectionConversion[O, C, T], + i3: CanBuildFrom[Nothing, T, C[T]] ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] { private lazy val encodeT = i0.value.encoder @@ -538,20 +561,20 @@ object TypedEncoder { if (ScalaReflection.isNativeType(enc.jvmRepr)) { NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr) } else { - MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable) + // converts to Seq, both Set and Seq handling must convert to Seq first + MapObjects(enc.toCatalyst, SeqCaster(path), enc.jvmRepr, encodeT.nullable) } } def fromCatalyst(path: Expression): Expression = - CollectionCaster( + CollectionCaster[O, C, T]( MapObjects( i0.value.fromCatalyst, path, encodeT.catalystRepr, encodeT.nullable, - Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly - ) - , implicitly[SeqConversion[C]]) + Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling + ), implicitly[CollectionConversion[O,C,T]]) // This will convert Seq to the appropriate C[_] when eval'ing. override def toString: String = s"collectionEncoder($jvmRepr)" } @@ -561,16 +584,18 @@ object TypedEncoder { * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type. * @tparam T the element type of the set. * @return a `TypedEncoder` instance for `Set[T]`. - */ - implicit def setEncoder[T]( + + implicit def setEncoder[C[X] <: Seq[X], T]( implicit i1: shapeless.Lazy[RecordFieldEncoder[T]], - i2: ClassTag[Set[T]] + i2: ClassTag[Set[T]], + i3: CollectionConversion[Set, C, T], + i4: CanBuildFrom[Nothing, T, C[T]] ): TypedEncoder[Set[T]] = { implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet) TypedEncoder.usingInjection - } + }*/ /** * @tparam A the key type diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala index 494ec112..fe144281 100644 --- a/dataset/src/test/scala/frameless/EncoderTests.scala +++ b/dataset/src/test/scala/frameless/EncoderTests.scala @@ -1,7 +1,6 @@ package frameless -import scala.collection.immutable.Set - +import scala.collection.immutable.{Set, TreeSet} import org.scalatest.matchers.should.Matchers object EncoderTests { @@ -12,6 +11,8 @@ object EncoderTests { case class PeriodRow(p: java.time.Period) case class VectorOfObject(a: Vector[X1[Int]]) + + case class TreeSetOfObjects(a: TreeSet[X1[Int]]) } class EncoderTests extends TypedDatasetSuite with Matchers { @@ -36,7 +37,7 @@ class EncoderTests extends TypedDatasetSuite with Matchers { } test("It should encode a Vector of Objects") { - forceInterpreted { + evalCodeGens { implicit val e = implicitly[TypedEncoder[VectorOfObject]] implicit val te = TypedExpressionEncoder[VectorOfObject] implicit val xe = implicitly[TypedEncoder[X1[VectorOfObject]]] @@ -48,4 +49,18 @@ class EncoderTests extends TypedDatasetSuite with Matchers { ds.head.a.a shouldBe v } } + + test("It should encode a TreeSet of Objects") { + evalCodeGens { + implicit val e = implicitly[TypedEncoder[TreeSetOfObjects]] + implicit val te = TypedExpressionEncoder[TreeSetOfObjects] + implicit val xe = implicitly[TypedEncoder[X1[TreeSetOfObjects]]] + implicit val xte = TypedExpressionEncoder[X1[TreeSetOfObjects]] + val v = (1 to 20).map(X1(_)).to[TreeSet] + val ds = { + sqlContext.createDataset(Seq(X1[TreeSetOfObjects](TreeSetOfObjects(v)))) + } + ds.head.a.a shouldBe v + } + } }