Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Mar 21, 2024
2 parents e36eac2 + 08d7c3d commit aa1e6de
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 46 deletions.
10 changes: 9 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,15 @@ lazy val datasetSettings =
mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"),
mc("frameless.MapGroups"),
mc(f"frameless.MapGroups$$"),
dmm("frameless.functions.package.litAggr")
dmm("frameless.functions.package.litAggr"),
dmm("org.apache.spark.sql.FramelessInternals.column"),
dmm("frameless.TypedEncoder.collectionEncoder"),
dmm("frameless.TypedEncoder.setEncoder"),
dmm("frameless.functions.FramelessUdf.evalCode"),
dmm("frameless.functions.FramelessUdf.copy"),
dmm("frameless.functions.FramelessUdf.this"),
dmm("frameless.functions.FramelessUdf.apply"),
imt("frameless.functions.FramelessUdf.apply")
)
},
coverageExcludedPackages := "frameless.reflection",
Expand Down
67 changes: 67 additions & 0 deletions dataset/src/main/scala/frameless/CollectionCaster.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package frameless

import frameless.TypedEncoder.CollectionConversion
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{
CodegenContext,
CodegenFallback,
ExprCode
}
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
import org.apache.spark.sql.types.{ DataType, ObjectType }

/**
 * Unary expression that hands the evaluated child value to `conversion`
 * when evaluated through `eval`.
 *
 * The reported `dataType` is the child's, unchanged; only the runtime value
 * is passed through `conversion.convert`.
 *
 * @param child      expression producing the source collection
 * @param conversion conversion applied to the evaluated child value
 * @tparam F source collection type accepted by `conversion`
 * @tparam C target collection type produced by `conversion`
 * @tparam Y element type
 */
case class CollectionCaster[F[_], C[_], Y](
    child: Expression,
    conversion: CollectionConversion[F, C, Y])
    extends UnaryExpression
    with CodegenFallback {

  override def dataType: DataType = child.dataType

  override def eval(input: InternalRow): Any =
    child.eval(input).asInstanceOf[Object] match {
      // The pattern is unchecked (F is abstract), so it cannot discriminate
      // between collection types at runtime; a null child value falls through.
      case source: F[Y] @unchecked => conversion.convert(source)
      case other                   => other
    }

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}

/**
 * Unary expression that turns a `Set` child value into a `Seq`, leaving any
 * other value untouched.
 *
 * Both evaluation paths use `scala.collection.Set` as the criterion, so the
 * declared `dataType`, the generated code and the interpreted `eval` agree on
 * which values get converted.
 *
 * @param child expression producing the collection to convert
 * @tparam C collection type of the child value
 * @tparam Y element type
 */
case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
    extends UnaryExpression {

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  // eval on interpreted works, fallback on codegen does not, e.g. with
  // ColumnTests.asCol and Vectors the generated code still has a child of type
  // Vector but child eval returns X2 - hence the explicit doGenCode below
  // instead of CodegenFallback.
  /**
   * Interpreted evaluation: any `scala.collection.Set` is converted with
   * `toSeq`; every other value (including null) is returned as is.
   */
  override def eval(input: InternalRow): Any =
    child.eval(input).asInstanceOf[Object] match {
      // Matches the same base type `toSeqOr` tests the static data type against.
      case set: scala.collection.Set[Y] @unchecked => set.toSeq
      case other                                   => other
    }

  /**
   * Selects between two alternatives based on the child's static data type.
   *
   * @param isSet result when the child is an `ObjectType` whose class is a
   *              `scala.collection.Set`
   * @param or    result for every other data type
   */
  def toSeqOr[T](isSet: => T, or: => T): T =
    child.dataType match {
      case ObjectType(cls)
          if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
        isSet
      case _ => or
    }

  /** `Seq` object type for Set children, otherwise the child's own data type. */
  override def dataType: DataType =
    toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)

  /** Generated code calls `toVector()` on Set children and passes other values through. */
  override protected def doGenCode(
      ctx: CodegenContext,
      ev: ExprCode
    ): ExprCode =
    defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c"))

}
112 changes: 85 additions & 27 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ import org.apache.spark.sql.catalyst.expressions.{
UnsafeArrayData,
Literal
}
import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.sql.{ reflection => ScalaReflection }

import org.apache.spark.sql.catalyst.util.{
ArrayBasedMapData,
DateTimeUtils,
GenericArrayData
}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import shapeless._
import shapeless.ops.hlist.IsHCons
import com.sparkutils.shim.expressions.{
Expand All @@ -34,6 +37,8 @@ import org.apache.spark.sql.shim.{
Invoke5 => Invoke
}

import scala.collection.immutable.{ ListSet, TreeSet }

abstract class TypedEncoder[T](
implicit
val classTag: ClassTag[T])
Expand Down Expand Up @@ -509,10 +514,70 @@ object TypedEncoder {
override def toString: String = s"arrayEncoder($jvmRepr)"
}

implicit def collectionEncoder[C[X] <: Seq[X], T](
/**
 * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq,
 * not the derived type used in compilation.
 *
 * This type class offers extensible conversion to more specific types. By default
 * Seq, List and Vector are supported as targets for Seq, and Set, TreeSet and
 * ListSet as targets for Set.
 *
 * @tparam F the source collection type
 * @tparam C the target collection type
 * @tparam Y the element type
 */
trait CollectionConversion[F[_], C[_], Y] extends Serializable {

  /** Converts the source collection `c` into the target collection type. */
  def convert(c: F[Y]): C[Y]
}

/**
 * Default [[CollectionConversion]] instances.
 *
 * Every instance declares its result type explicitly so implicit resolution
 * does not depend on the type inferred for the anonymous class.
 */
object CollectionConversion {

  /** Identity conversion for `Seq`. */
  implicit def seqToSeq[Y]: CollectionConversion[Seq, Seq, Y] =
    new CollectionConversion[Seq, Seq, Y] {
      override def convert(c: Seq[Y]): Seq[Y] = c
    }

  /** `Seq` to `Vector`. */
  implicit def seqToVector[Y]: CollectionConversion[Seq, Vector, Y] =
    new CollectionConversion[Seq, Vector, Y] {
      override def convert(c: Seq[Y]): Vector[Y] = c.toVector
    }

  /** `Seq` to `List`. */
  implicit def seqToList[Y]: CollectionConversion[Seq, List, Y] =
    new CollectionConversion[Seq, List, Y] {
      override def convert(c: Seq[Y]): List[Y] = c.toList
    }

  /** Identity conversion for `Set`. */
  implicit def setToSet[Y]: CollectionConversion[Set, Set, Y] =
    new CollectionConversion[Set, Set, Y] {
      override def convert(c: Set[Y]): Set[Y] = c
    }

  /** `Set` to `TreeSet`; requires an `Ordering` for the element type. */
  implicit def setToTreeSet[Y](
      implicit
      ordering: Ordering[Y]
    ): CollectionConversion[Set, TreeSet, Y] =
    new CollectionConversion[Set, TreeSet, Y] {

      override def convert(c: Set[Y]): TreeSet[Y] =
        (TreeSet.newBuilder[Y] ++= c).result()
    }

  /** `Set` to `ListSet`. */
  implicit def setToListSet[Y]: CollectionConversion[Set, ListSet, Y] =
    new CollectionConversion[Set, ListSet, Y] {

      override def convert(c: Set[Y]): ListSet[Y] =
        (ListSet.newBuilder[Y] ++= c).result()
    }
}

/**
 * Implicit encoder for any `Seq` subtype `C[T]` for which a
 * `CollectionConversion[Seq, C, T]` is available; delegates to
 * `collectionEncoder` with `Seq` as the source collection type.
 *
 * @param i0 lazy `RecordFieldEncoder[T]` for the elements
 * @param i1 `ClassTag` for the concrete collection type
 * @param i2 conversion from `Seq` to `C`
 * @tparam C the `Seq` subtype being encoded
 * @tparam T the element type
 */
implicit def seqEncoder[C[X] <: Seq[X], T](
    implicit
    i0: Lazy[RecordFieldEncoder[T]],
    i1: ClassTag[C[T]],
    i2: CollectionConversion[Seq, C, T]
  ): TypedEncoder[C[T]] = collectionEncoder[Seq, C, T]

implicit def setEncoder[C[X] <: Set[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]]
i1: ClassTag[C[T]],
i2: CollectionConversion[Set, C, T]
) = collectionEncoder[Set, C, T]

def collectionEncoder[O[_], C[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]],
i2: CollectionConversion[O, C, T]
): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
private lazy val encodeT = i0.value.encoder

Expand All @@ -529,38 +594,31 @@ object TypedEncoder {
if (ScalaReflection.isNativeType(enc.jvmRepr)) {
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
} else {
MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
// converts to Seq, both Set and Seq handling must convert to Seq first
MapObjects(
enc.toCatalyst,
SeqCaster(path),
enc.jvmRepr,
encodeT.nullable
)
}
}

def fromCatalyst(path: Expression): Expression =
MapObjects(
i0.value.fromCatalyst,
path,
encodeT.catalystRepr,
encodeT.nullable,
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
)
CollectionCaster[O, C, T](
MapObjects(
i0.value.fromCatalyst,
path,
encodeT.catalystRepr,
encodeT.nullable,
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
),
implicitly[CollectionConversion[O, C, T]]
) // This will convert Seq to the appropriate C[_] when eval'ing.

override def toString: String = s"collectionEncoder($jvmRepr)"
}

/**
* @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set.
* @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
* @tparam T the element type of the set.
* @return a `TypedEncoder` instance for `Set[T]`.
*/
implicit def setEncoder[T](
implicit
i1: shapeless.Lazy[RecordFieldEncoder[T]],
i2: ClassTag[Set[T]]
): TypedEncoder[Set[T]] = {
implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)

TypedEncoder.usingInjection
}

/**
* @tparam A the key type
* @tparam B the value type
Expand Down
16 changes: 0 additions & 16 deletions dataset/src/main/scala/frameless/functions/Udf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,12 @@ trait Udf {
) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
<<<<<<< HEAD
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
=======
FramelessUdf(
f,
us.toList[UntypedExpression[T]],
TypedEncoder[R],
s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2])
)
>>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start)
new TypedColumn[T, R](scalaUdf)
}

Expand All @@ -75,9 +71,6 @@ trait Udf {
) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
<<<<<<< HEAD
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
=======
FramelessUdf(
f,
us.toList[UntypedExpression[T]],
Expand All @@ -89,7 +82,6 @@ trait Udf {
s(2).asInstanceOf[A3]
)
)
>>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start)
new TypedColumn[T, R](scalaUdf)
}

Expand All @@ -102,9 +94,6 @@ trait Udf {
def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
<<<<<<< HEAD
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
=======
FramelessUdf(
f,
us.toList[UntypedExpression[T]],
Expand All @@ -117,7 +106,6 @@ trait Udf {
s(3).asInstanceOf[A4]
)
)
>>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start)
new TypedColumn[T, R](scalaUdf)
}

Expand All @@ -130,9 +118,6 @@ trait Udf {
def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
<<<<<<< HEAD
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
=======
FramelessUdf(
f,
us.toList[UntypedExpression[T]],
Expand All @@ -146,7 +131,6 @@ trait Udf {
s(4).asInstanceOf[A5]
)
)
>>>>>>> 3bdb8ad (#803 - clean udf from #804, no shim start)
new TypedColumn[T, R](scalaUdf)
}
}
Expand Down
56 changes: 54 additions & 2 deletions dataset/src/test/scala/frameless/EncoderTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package frameless

import scala.collection.immutable.Set

import scala.collection.immutable.{ ListSet, Set, TreeSet }
import org.scalatest.matchers.should.Matchers

object EncoderTests {
Expand All @@ -10,6 +9,8 @@ object EncoderTests {
case class InstantRow(i: java.time.Instant)
case class DurationRow(d: java.time.Duration)
case class PeriodRow(p: java.time.Period)

case class ContainerOf[CC[X] <: Iterable[X]](a: CC[X1[Int]])
}

class EncoderTests extends TypedDatasetSuite with Matchers {
Expand All @@ -32,4 +33,55 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
test("It should encode java.time.Period") {
implicitly[TypedEncoder[PeriodRow]]
}

/**
 * Round-trips a collection of `X1[Int]` through a Dataset and checks it is
 * read back equal to the original.
 *
 * The values `X1(1)` to `X1(20)` are turned into a `C[X1[Int]]` with `toType`,
 * wrapped as `X1[ContainerOf[C]]`, stored in a single-row Dataset, and the
 * collection read from `head` is compared with the one that was written.
 *
 * @param toType builds the collection under test from a `Seq` of elements
 * @param ce     encoder for the collection type under test
 * @tparam C the collection type under test
 * @return the result of `evalCodeGens`
 */
// NOTE(review): evalCodeGens is defined elsewhere; presumably it runs the body
// once per code generation mode, hence the pair result - confirm.
def performCollection[C[X] <: Iterable[X]](
toType: Seq[X1[Int]] => C[X1[Int]]
)(implicit
ce: TypedEncoder[C[X1[Int]]]
): (Unit, Unit) = evalCodeGens {

// Encoders are declared as implicit vals in dependency order: each one is in
// implicit scope for the declarations after it and for the createDataset call.
implicit val cte = TypedExpressionEncoder[C[X1[Int]]]
implicit val e = implicitly[TypedEncoder[ContainerOf[C]]]
implicit val te = TypedExpressionEncoder[ContainerOf[C]]
implicit val xe = implicitly[TypedEncoder[X1[ContainerOf[C]]]]
implicit val xte = TypedExpressionEncoder[X1[ContainerOf[C]]]
// 20 elements converted to the collection type under test.
val v = toType((1 to 20).map(X1(_)))
val ds = {
sqlContext.createDataset(Seq(X1[ContainerOf[C]](ContainerOf[C](v))))
}
// Unwrap the outer X1 and the ContainerOf to reach the collection read back.
ds.head.a.a shouldBe v
()
}

test("It should serde a Seq of Objects") {
performCollection[Seq](_)
}

test("It should serde a Set of Objects") {
performCollection[Set](_)
}

test("It should serde a Vector of Objects") {
performCollection[Vector](_.toVector)
}

test("It should serde a TreeSet of Objects") {
// only needed for 2.12
implicit val ordering = new Ordering[X1[Int]] {
val intordering = implicitly[Ordering[Int]]

override def compare(x: X1[Int], y: X1[Int]): Int =
intordering.compare(x.a, y.a)
}

performCollection[TreeSet](TreeSet.newBuilder.++=(_).result())
}

test("It should serde a List of Objects") {
performCollection[List](_.toList)
}

test("It should serde a ListSet of Objects") {
performCollection[ListSet](ListSet.newBuilder.++=(_).result())
}
}
Loading

0 comments on commit aa1e6de

Please sign in to comment.