typelevel#804 - encoding for Set derivatives as well - test build
chris-twiner committed Mar 20, 2024
1 parent ee38804 commit fb1c109
Showing 3 changed files with 112 additions and 44 deletions.
dataset/src/main/scala/frameless/CollectionCaster.scala (44 changes: 36 additions & 8 deletions)
@@ -1,24 +1,52 @@
package frameless

import frameless.TypedEncoder.SeqConversion
import frameless.TypedEncoder.CollectionConversion
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, ObjectType}

case class CollectionCaster[C[_]](child: Expression, conversion: SeqConversion[C]) extends UnaryExpression with CodegenFallback {
case class CollectionCaster[F[_],C[_],Y](child: Expression, conversion: CollectionConversion[F,C,Y]) extends UnaryExpression with CodegenFallback {
protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)

override def eval(input: InternalRow): Any = {
val o = child.eval(input).asInstanceOf[Object]
o match {
case seq: scala.collection.Seq[_] =>
conversion.convertSeq(seq)
case set: scala.collection.Set[_] =>
o
case col: F[Y] @unchecked =>
conversion.convert(col)
case _ => o
}
}

override def dataType: DataType = child.dataType
}

case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) extends UnaryExpression {
protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)

// eval works in interpreted mode, but falling back on codegen does not: e.g. with ColumnTests.asCol and Vectors, the generated code still has a child of type Vector while the child's eval returns X2, which does not match
override def eval(input: InternalRow): Any = {
val o = child.eval(input).asInstanceOf[Object]
o match {
case col: Set[Y]@unchecked =>
col.toSeq
case _ => o
}
}

def toSeqOr[T](isSet: => T, or: => T): T =
child.dataType match {
case ObjectType(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
isSet
case t => or
}

override def dataType: DataType =
toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c =>
toSeqOr(s"$c.toSeq()", s"$c")
)

}
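
For context, the intent of the two casters above can be sketched with plain collections (no Spark involved, and not part of this commit): SeqCaster views a Set as a Seq before MapObjects serialises the elements, and CollectionCaster uses a CollectionConversion instance to rebuild the requested collection when the interpreted path hands back a plain Seq.

object CasterSketch extends App {
  import scala.collection.immutable.TreeSet

  val original: TreeSet[Int] = TreeSet(3, 1, 2)

  // Serialisation side (what SeqCaster achieves): any Set is first viewed as a Seq.
  val asSeq: Seq[Int] = original.toSeq

  // Deserialisation side (what CollectionCaster achieves in interpreted mode):
  // MapObjects hands back a Seq and the conversion rebuilds the target collection.
  val rebuilt: TreeSet[Int] = TreeSet.empty[Int] ++ asSeq

  assert(rebuilt == original) // the round trip preserves the TreeSet
}
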
dataset/src/main/scala/frameless/TypedEncoder.scala (91 changes: 58 additions & 33 deletions)
@@ -1,31 +1,24 @@
package frameless

import java.math.BigInteger

import java.util.Date

import java.time.{ Duration, Instant, Period, LocalDate }

import java.time.{Duration, Instant, LocalDate, Period}
import java.sql.Timestamp

import scala.reflect.ClassTag

import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.sql.{ reflection => ScalaReflection }
import org.apache.spark.sql.{reflection => ScalaReflection}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{
ArrayBasedMapData,
DateTimeUtils,
GenericArrayData
}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import shapeless._
import shapeless.ops.hlist.IsHCons

import scala.collection.generic.CanBuildFrom
import scala.collection.immutable.TreeSet

abstract class TypedEncoder[T](
implicit
val classTag: ClassTag[T])
@@ -501,27 +494,57 @@ object TypedEncoder {
override def toString: String = s"arrayEncoder($jvmRepr)"
}

trait SeqConversion[C[_]] extends Serializable {
def convertSeq[Y](c: Seq[Y]): C[Y]
/**
* Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used during compilation.
*
* This type class offers extensible conversion to more specific collection types. By default Seq, List, Vector, Set and TreeSet are supported.
*
* @tparam F the collection type produced at runtime (Seq or Set)
* @tparam C the collection type the encoder is derived for
* @tparam Y the element type
*/
trait CollectionConversion[F[_], C[_], Y] extends Serializable {
def convert(c: F[Y]): C[Y]
}

object SeqConversion {
implicit val seqToSeq = new SeqConversion[Seq] {
override def convertSeq[Y](c: Seq[Y]): Seq[Y] = c
object CollectionConversion {
implicit def seqToSeq[Y](implicit cbf: CanBuildFrom[Nothing, Y, Seq[Y]]) = new CollectionConversion[Seq, Seq, Y] {
override def convert(c: Seq[Y]): Seq[Y] = c
}
implicit def seqToVector[Y](implicit cbf: CanBuildFrom[Nothing, Y, Vector[Y]]) = new CollectionConversion[Seq, Vector, Y] {
override def convert(c: Seq[Y]): Vector[Y] = c.toVector
}
implicit def seqToList[Y](implicit cbf: CanBuildFrom[Nothing, Y, List[Y]]) = new CollectionConversion[Seq, List, Y] {
override def convert(c: Seq[Y]): List[Y] = c.toList
}
implicit val seqToVector = new SeqConversion[Vector] {
override def convertSeq[Y](c: Seq[Y]): Vector[Y] = c.toVector
implicit def setToSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, Set[Y]]) = new CollectionConversion[Set, Set, Y] {
override def convert(c: Set[Y]): Set[Y] = c
}
implicit val seqToList = new SeqConversion[List] {
override def convertSeq[Y](c: Seq[Y]): List[Y] = c.toList
implicit def setToTreeSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, TreeSet[Y]]) = new CollectionConversion[Set, TreeSet, Y] {
override def convert(c: Set[Y]): TreeSet[Y] = c.to[TreeSet]
}
}

implicit def collectionEncoder[C[X] <: Seq[X], T](
implicit def seqEncoder[C[X] <: Seq[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]],
i2: CollectionConversion[Seq, C, T],
i3: CanBuildFrom[Nothing, T, C[T]]
) = collectionEncoder[Seq, C, T]

implicit def setEncoder[C[X] <: Set[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]],
i2: CollectionConversion[Set, C, T],
i3: CanBuildFrom[Nothing, T, C[T]]
) = collectionEncoder[Set, C, T]

def collectionEncoder[O[_], C[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]],
i2: SeqConversion[C]
i2: CollectionConversion[O, C, T],
i3: CanBuildFrom[Nothing, T, C[T]]
): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
private lazy val encodeT = i0.value.encoder

@@ -538,20 +561,20 @@ object TypedEncoder {
if (ScalaReflection.isNativeType(enc.jvmRepr)) {
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
} else {
MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
// converts to Seq first; both the Set and Seq paths must go through Seq before MapObjects
MapObjects(enc.toCatalyst, SeqCaster(path), enc.jvmRepr, encodeT.nullable)
}
}

def fromCatalyst(path: Expression): Expression =
CollectionCaster(
CollectionCaster[O, C, T](
MapObjects(
i0.value.fromCatalyst,
path,
encodeT.catalystRepr,
encodeT.nullable,
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
)
, implicitly[SeqConversion[C]])
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
), implicitly[CollectionConversion[O,C,T]]) // This will convert Seq to the appropriate C[_] when eval'ing.

override def toString: String = s"collectionEncoder($jvmRepr)"
}
@@ -561,16 +584,18 @@ object TypedEncoder {
* @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
* @tparam T the element type of the set.
* @return a `TypedEncoder` instance for `Set[T]`.
*/
implicit def setEncoder[T](
implicit def setEncoder[C[X] <: Seq[X], T](
implicit
i1: shapeless.Lazy[RecordFieldEncoder[T]],
i2: ClassTag[Set[T]]
i2: ClassTag[Set[T]],
i3: CollectionConversion[Set, C, T],
i4: CanBuildFrom[Nothing, T, C[T]]
): TypedEncoder[Set[T]] = {
implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
TypedEncoder.usingInjection
}
}*/

/**
* @tparam A the key type
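
Since the CollectionConversion scaladoc above describes the type class as extensible, a hedged sketch of an additional user-supplied instance could look like the following, mirroring the shape of seqToVector and seqToList; Queue is only an illustrative target and is not wired up by this commit.

import scala.collection.generic.CanBuildFrom
import scala.collection.immutable.Queue
import frameless.TypedEncoder.CollectionConversion

object ExtraConversions {
  // Hypothetical instance converting the Seq produced at runtime into a Queue.
  implicit def seqToQueue[Y](
      implicit cbf: CanBuildFrom[Nothing, Y, Queue[Y]]
    ): CollectionConversion[Seq, Queue, Y] =
    new CollectionConversion[Seq, Queue, Y] {
      override def convert(c: Seq[Y]): Queue[Y] = Queue(c: _*)
    }
}
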
dataset/src/test/scala/frameless/EncoderTests.scala (21 changes: 18 additions & 3 deletions)
@@ -1,7 +1,6 @@
package frameless

import scala.collection.immutable.Set

import scala.collection.immutable.{Set, TreeSet}
import org.scalatest.matchers.should.Matchers

object EncoderTests {
@@ -12,6 +11,8 @@ object EncoderTests {
case class PeriodRow(p: java.time.Period)

case class VectorOfObject(a: Vector[X1[Int]])

case class TreeSetOfObjects(a: TreeSet[X1[Int]])
}

class EncoderTests extends TypedDatasetSuite with Matchers {
@@ -36,7 +37,7 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
}

test("It should encode a Vector of Objects") {
forceInterpreted {
evalCodeGens {
implicit val e = implicitly[TypedEncoder[VectorOfObject]]
implicit val te = TypedExpressionEncoder[VectorOfObject]
implicit val xe = implicitly[TypedEncoder[X1[VectorOfObject]]]
@@ -48,4 +49,18 @@
ds.head.a.a shouldBe v
}
}

test("It should encode a TreeSet of Objects") {
evalCodeGens {
implicit val e = implicitly[TypedEncoder[TreeSetOfObjects]]
implicit val te = TypedExpressionEncoder[TreeSetOfObjects]
implicit val xe = implicitly[TypedEncoder[X1[TreeSetOfObjects]]]
implicit val xte = TypedExpressionEncoder[X1[TreeSetOfObjects]]
val v = (1 to 20).map(X1(_)).to[TreeSet]
val ds = {
sqlContext.createDataset(Seq(X1[TreeSetOfObjects](TreeSetOfObjects(v))))
}
ds.head.a.a shouldBe v
}
}
}
