Commit 5202962

chris-twiner committed Mar 20, 2024
2 parents afc7ec5 + f0d5f16
Showing 7 changed files with 439 additions and 118 deletions.
5 changes: 4 additions & 1 deletion build.sbt
@@ -273,7 +273,10 @@ lazy val datasetSettings =
mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"),
mc("frameless.MapGroups"),
mc(f"frameless.MapGroups$$"),
dmm("frameless.functions.package.litAggr")
dmm("frameless.functions.package.litAggr"),
dmm("org.apache.spark.sql.FramelessInternals.column"),
dmm("frameless.TypedEncoder.collectionEncoder"),
dmm("frameless.TypedEncoder.setEncoder")
)
},
coverageExcludedPackages := "frameless.reflection",
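The three new dmm entries exclude the reworked encoder methods from MiMa's binary-compatibility checks, matching the signature changes below. The mc/dmm helpers are defined outside this hunk; presumably they wrap MiMa's problem filters roughly as follows (an assumption, not part of the diff):

    import com.typesafe.tools.mima.core._
    // assumed helpers: mc = exclude a missing class, dmm = exclude a removed/changed method
    def mc(s: String): ProblemFilter = ProblemFilters.exclude[MissingClassProblem](s)
    def dmm(s: String): ProblemFilter = ProblemFilters.exclude[DirectMissingMethodProblem](s)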
67 changes: 67 additions & 0 deletions dataset/src/main/scala/frameless/CollectionCaster.scala
@@ -0,0 +1,67 @@
package frameless

import frameless.TypedEncoder.CollectionConversion
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{
CodegenContext,
CodegenFallback,
ExprCode
}
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
import org.apache.spark.sql.types.{ DataType, ObjectType }

case class CollectionCaster[F[_], C[_], Y](
child: Expression,
conversion: CollectionConversion[F, C, Y])
extends UnaryExpression
with CodegenFallback {

protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)

override def eval(input: InternalRow): Any = {
val o = child.eval(input).asInstanceOf[Object]
o match {
case col: F[Y] @unchecked =>
conversion.convert(col)
case _ => o
}
}

override def dataType: DataType = child.dataType
}

case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
extends UnaryExpression {

protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)

// eval works when interpreted, but a CodegenFallback does not: e.g. with ColumnTests.asCol and Vectors, the generated code still types the child as Vector while child.eval returns X2, which is invalid, hence the custom doGenCode below
override def eval(input: InternalRow): Any = {
val o = child.eval(input).asInstanceOf[Object]
o match {
case col: Set[Y] @unchecked =>
col.toSeq
case _ => o
}
}

def toSeqOr[T](isSet: => T, or: => T): T =
child.dataType match {
case ObjectType(cls)
if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
isSet
case t => or
}

override def dataType: DataType =
toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)

override protected def doGenCode(
ctx: CodegenContext,
ev: ExprCode
): ExprCode =
defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c"))

}
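For context (not part of the commit), the runtime effect of CollectionCaster can be sketched directly against the CollectionConversion instances it carries: in interpreted mode MapObjects hands back a plain Seq, and convert rebuilds the requested collection type:

    // minimal sketch, assuming the instances defined in TypedEncoder.CollectionConversion below
    import scala.collection.immutable.TreeSet
    import frameless.TypedEncoder.CollectionConversion

    val fromMapObjects: Seq[Int] = Seq(3, 1, 2) // what interpreted MapObjects produces
    val asVector: Vector[Int] =
      implicitly[CollectionConversion[Seq, Vector, Int]].convert(fromMapObjects)
    val asTreeSet: TreeSet[Int] = // picks up setToTreeSet, which needs an Ordering[Int]
      implicitly[CollectionConversion[Set, TreeSet, Int]].convert(fromMapObjects.toSet)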
109 changes: 81 additions & 28 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -5,7 +5,6 @@ import java.util.Date
import java.time.{ Duration, Instant, LocalDate, Period }
import java.sql.Timestamp
import scala.reflect.ClassTag
- import FramelessInternals.UserDefinedType
import org.apache.spark.sql.catalyst.expressions.{
Expression,
UnsafeArrayData,
@@ -18,7 +17,6 @@ import org.apache.spark.sql.catalyst.util.{
}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import shapeless._
import shapeless.ops.hlist.IsHCons
import com.sparkutils.shim.expressions.{
@@ -34,6 +32,8 @@ import org.apache.spark.sql.shim.{
Invoke5 => Invoke
}

+ import scala.collection.immutable.{ ListSet, TreeSet }

abstract class TypedEncoder[T](
implicit
val classTag: ClassTag[T])
@@ -509,10 +509,70 @@ object TypedEncoder {
override def toString: String = s"arrayEncoder($jvmRepr)"
}

- implicit def collectionEncoder[C[X] <: Seq[X], T](
+ /**
+ * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation.
+ *
+ * This type class offers extensible conversions to more specific types. By default Seq, List and Vector are supported for Seq, and Set, TreeSet and ListSet for Set.
+ *
+ * @tparam F the collection type actually produced at runtime
+ * @tparam C the collection type to convert to
+ * @tparam Y the element type
+ */
+ trait CollectionConversion[F[_], C[_], Y] extends Serializable {
+ def convert(c: F[Y]): C[Y]
+ }
+
+ object CollectionConversion {
+
+ implicit def seqToSeq[Y] = new CollectionConversion[Seq, Seq, Y] {
+ override def convert(c: Seq[Y]): Seq[Y] = c
+ }
+
+ implicit def seqToVector[Y] = new CollectionConversion[Seq, Vector, Y] {
+ override def convert(c: Seq[Y]): Vector[Y] = c.toVector
+ }
+
+ implicit def seqToList[Y] = new CollectionConversion[Seq, List, Y] {
+ override def convert(c: Seq[Y]): List[Y] = c.toList
+ }
+
+ implicit def setToSet[Y] = new CollectionConversion[Set, Set, Y] {
+ override def convert(c: Set[Y]): Set[Y] = c
+ }
+
+ implicit def setToTreeSet[Y](
+ implicit
+ ordering: Ordering[Y]
+ ) = new CollectionConversion[Set, TreeSet, Y] {
+
+ override def convert(c: Set[Y]): TreeSet[Y] =
+ TreeSet.newBuilder.++=(c).result()
+ }
+
+ implicit def setToListSet[Y] = new CollectionConversion[Set, ListSet, Y] {
+
+ override def convert(c: Set[Y]): ListSet[Y] =
+ ListSet.newBuilder.++=(c).result()
+ }
+ }
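The type class above is deliberately open for extension. A hypothetical user-supplied instance (not part of the commit) for another Seq subtype, following the same builder pattern as setToListSet:

    import scala.collection.immutable.Queue
    import frameless.TypedEncoder.CollectionConversion

    // hypothetical: would let seqEncoder below derive a TypedEncoder[Queue[T]]
    implicit def seqToQueue[Y]: CollectionConversion[Seq, Queue, Y] =
      new CollectionConversion[Seq, Queue, Y] {
        override def convert(c: Seq[Y]): Queue[Y] =
          Queue.newBuilder.++=(c).result()
      }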

+ implicit def seqEncoder[C[X] <: Seq[X], T](
+ implicit
+ i0: Lazy[RecordFieldEncoder[T]],
+ i1: ClassTag[C[T]],
+ i2: CollectionConversion[Seq, C, T]
+ ) = collectionEncoder[Seq, C, T]
+
+ implicit def setEncoder[C[X] <: Set[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
- i1: ClassTag[C[T]]
+ i1: ClassTag[C[T]],
+ i2: CollectionConversion[Set, C, T]
+ ) = collectionEncoder[Set, C, T]
+
+ def collectionEncoder[O[_], C[X], T](
+ implicit
+ i0: Lazy[RecordFieldEncoder[T]],
+ i1: ClassTag[C[T]],
+ i2: CollectionConversion[O, C, T]
): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
private lazy val encodeT = i0.value.encoder

@@ -529,38 +589,31 @@ object TypedEncoder {
if (ScalaReflection.isNativeType(enc.jvmRepr)) {
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
} else {
- MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
+ // converts to Seq, both Set and Seq handling must convert to Seq first
+ MapObjects(
+ enc.toCatalyst,
+ SeqCaster(path),
+ enc.jvmRepr,
+ encodeT.nullable
+ )
}
}

def fromCatalyst(path: Expression): Expression =
- MapObjects(
- i0.value.fromCatalyst,
- path,
- encodeT.catalystRepr,
- encodeT.nullable,
- Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
- )
+ CollectionCaster[O, C, T](
+ MapObjects(
+ i0.value.fromCatalyst,
+ path,
+ encodeT.catalystRepr,
+ encodeT.nullable,
+ Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
+ ),
+ implicitly[CollectionConversion[O, C, T]]
+ ) // This will convert Seq to the appropriate C[_] when eval'ing.

override def toString: String = s"collectionEncoder($jvmRepr)"
}

- /**
- * @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set.
- * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
- * @tparam T the element type of the set.
- * @return a `TypedEncoder` instance for `Set[T]`.
- */
- implicit def setEncoder[T](
- implicit
- i1: shapeless.Lazy[RecordFieldEncoder[T]],
- i2: ClassTag[Set[T]]
- ): TypedEncoder[Set[T]] = {
- implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
-
- TypedEncoder.usingInjection
- }

/**
* @tparam A the key type
* @tparam B the value type
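Net effect: the concrete collection type now survives both code-generated and interpreted evaluation. An illustrative sketch (names are hypothetical; assumes an implicit SparkSession in scope):

    import frameless.TypedDataset
    import scala.collection.immutable.TreeSet

    case class Holder(xs: Vector[Int], s: TreeSet[String]) // hypothetical

    // TypedEncoder[Vector[Int]] resolves via seqEncoder + seqToVector;
    // TypedEncoder[TreeSet[String]] via setEncoder + setToTreeSet (needs Ordering[String])
    val ds = TypedDataset.create(Seq(Holder(Vector(1, 2), TreeSet("a", "b"))))
    // per #804, collecting yields Vector/TreeSet again, even when Spark
    // falls back to interpreted MapObjects instead of codegen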
58 changes: 46 additions & 12 deletions dataset/src/main/scala/frameless/functions/Udf.scala
@@ -2,7 +2,11 @@ package frameless
package functions

import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
+ import org.apache.spark.sql.catalyst.expressions.{
+ Expression,
+ LeafExpression,
+ NonSQLExpression
+ }
import org.apache.spark.sql.catalyst.expressions.codegen._
import Block._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -67,8 +71,17 @@ trait Udf {
) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
- FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
- s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3]))
+ FramelessUdf(
+ f,
+ us.toList[UntypedExpression[T]],
+ TypedEncoder[R],
+ s =>
+ f(
+ s.head.asInstanceOf[A1],
+ s(1).asInstanceOf[A2],
+ s(2).asInstanceOf[A3]
+ )
+ )
new TypedColumn[T, R](scalaUdf)
}

@@ -81,8 +94,18 @@
def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
- FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
- s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4]))
+ FramelessUdf(
+ f,
+ us.toList[UntypedExpression[T]],
+ TypedEncoder[R],
+ s =>
+ f(
+ s.head.asInstanceOf[A1],
+ s(1).asInstanceOf[A2],
+ s(2).asInstanceOf[A3],
+ s(3).asInstanceOf[A4]
+ )
+ )
new TypedColumn[T, R](scalaUdf)
}

@@ -95,8 +118,19 @@
def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
case us =>
val scalaUdf =
- FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
- s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4], s(1).asInstanceOf[A5]))
+ FramelessUdf(
+ f,
+ us.toList[UntypedExpression[T]],
+ TypedEncoder[R],
+ s =>
+ f(
+ s.head.asInstanceOf[A1],
+ s(1).asInstanceOf[A2],
+ s(2).asInstanceOf[A3],
+ s(3).asInstanceOf[A4],
+ s(4).asInstanceOf[A5]
+ )
+ )
new TypedColumn[T, R](scalaUdf)
}
}
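An illustrative sketch of the bug the new indices fix (X3 is the three-field helper case class from frameless's test sources; assumes import frameless.functions._ and a TypedDataset ds over X3[Int, Int, Int]):

    // before this commit, the interpreted path read s(1) for every argument
    // after the first, so a 3-arg UDF effectively computed f(a, b, b)
    val sum3 = udf[X3[Int, Int, Int], Int, Int, Int, Int](_ + _ + _)
    val col = sum3(ds('a), ds('b), ds('c)) // now evaluates a + b + c in eval too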
@@ -119,7 +153,8 @@ case class FramelessUdf[T, R](

override def toString: String = s"FramelessUdf(${children.mkString(", ")})"

- lazy val typedEnc = TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]
+ lazy val typedEnc =
+ TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]

def eval(input: InternalRow): Any = {
val jvmTypes = children.map(_.eval(input))
@@ -130,11 +165,10 @@ case class FramelessUdf[T, R](
val retval =
if (returnCatalyst == null)
null
+ else if (typedEnc.isSerializedAsStructForTopLevel)
+ returnCatalyst
else
- if (typedEnc.isSerializedAsStructForTopLevel)
- returnCatalyst
- else
- returnCatalyst.get(0, dataType)
+ returnCatalyst.get(0, dataType)

retval
}
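The final unwrap follows how TypedExpressionEncoder represents results; a hedged illustration (type names assumed, not part of the commit):

    // FramelessUdf[T, X2[Int, String]]: the result is serialized as a top-level
    // struct, isSerializedAsStructForTopLevel is true, the InternalRow is returned whole
    // FramelessUdf[T, Int]: the serializer wraps the value in a single-field row,
    // so eval unwraps it with returnCatalyst.get(0, dataType)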
(Diffs for the remaining three changed files did not load and are not shown.)