Skip to content

Commit

Permalink
typelevel#803 - clean udf eval needs typelevel#804
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Mar 21, 2024
1 parent 3bdb8ad commit c2f3492
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
6 changes: 5 additions & 1 deletion dataset/src/main/scala/frameless/functions/Udf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ trait Udf {
*
* Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
*/
// TODO: consider mixing in the UserDefinedExpression trait to stop these functions from being registered and used as aggregates
case class FramelessUdf[T, R](
function: AnyRef,
encoders: Seq[TypedEncoder[_]],
Expand All @@ -156,6 +157,9 @@ case class FramelessUdf[T, R](
/**
 * Encoder for the UDF result type `R`: built from `rencoder` through
 * [[TypedExpressionEncoder]] and viewed as an `ExpressionEncoder[R]`.
 * Being a `lazy val`, it is constructed once, on first access.
 */
lazy val typedEnc: ExpressionEncoder[R] =
  TypedExpressionEncoder[R](rencoder)
    .asInstanceOf[ExpressionEncoder[R]]

/**
 * Cached copy of `typedEnc.isSerializedAsStructForTopLevel`.
 * Held in a `lazy val` so `eval` reads a stored flag instead of asking the
 * encoder again on every call.
 */
lazy val isSerializedAsStructForTopLevel: Boolean =
  typedEnc.isSerializedAsStructForTopLevel

def eval(input: InternalRow): Any = {
val jvmTypes = children.map(_.eval(input))

Expand All @@ -165,7 +169,7 @@ case class FramelessUdf[T, R](
val retval =
if (returnCatalyst == null)
null
else if (typedEnc.isSerializedAsStructForTopLevel)
else if (isSerializedAsStructForTopLevel)
returnCatalyst
else
returnCatalyst.get(0, dataType)
Expand Down
60 changes: 56 additions & 4 deletions dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import java.time.format.DateTimeFormatter
import java.time.{ LocalDateTime => JavaLocalDateTime }
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.internal.SQLConf
import org.scalacheck.{ Arbitrary, Cogen, Gen }

import java.time.format.DateTimeFormatter
import java.time.{ LocalDateTime => JavaLocalDateTime }
import org.scalacheck.{ Arbitrary, Gen }
import scala.collection.immutable.{ ListSet, TreeSet }

package object frameless {

Expand Down Expand Up @@ -49,6 +50,46 @@ package object frameless {

/** Generator of `scala.collection.Seq[A]`, taken from the `arbSeq` instance. */
def seqGen[A](
  implicit
  A: Arbitrary[A]
): Gen[scala.collection.Seq[A]] =
  arbSeq[A].arbitrary

/**
 * `Arbitrary` instance for `List[A]`, built from the element generator.
 *
 * `Gen.listOf` already produces a `Gen[List[A]]`, so no further conversion
 * of the generated value is needed.
 *
 * @param A the `Arbitrary` supplying the list elements
 */
implicit def arbList[A](
  implicit
  A: Arbitrary[A]
): Arbitrary[List[A]] =
  Arbitrary(Gen.listOf(A.arbitrary))

/** Generator of `List[A]`, taken from the [[arbList]] instance. */
def listGen[A](
  implicit
  A: Arbitrary[A]
): Gen[List[A]] =
  arbList(A).arbitrary

/**
 * `Arbitrary` instance for `Set[A]`: generates a list of elements and
 * collapses it into a set, so duplicates in the generated list are dropped.
 *
 * @param A the `Arbitrary` supplying the set elements
 */
implicit def arbSet[A](
  implicit
  A: Arbitrary[A]
): Arbitrary[Set[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map(elems => elems.toSet[A])
  }

/** Generator of `Set[A]`, taken from the [[arbSet]] instance. */
def setGen[A](
  implicit
  A: Arbitrary[A]
): Gen[Set[A]] =
  arbSet(A).arbitrary

/**
 * `Cogen` for `ListSet[A]`.
 *
 * The elements are sorted with the implicit `Ordering[A]` before being fed
 * to the `Cogen`, so the result does not depend on the set's iteration order.
 */
implicit def cogenListSet[A: Cogen: Ordering]: Cogen[ListSet[A]] =
  Cogen.it { (set: ListSet[A]) =>
    val ordered = set.toVector.sorted
    ordered.iterator
  }

/**
 * `Arbitrary` instance for `ListSet[A]`: generates a list of elements and
 * adds them, in generated order, to a `ListSet` builder.
 *
 * @param A the `Arbitrary` supplying the set elements
 */
implicit def arbListSet[A](
  implicit
  A: Arbitrary[A]
): Arbitrary[ListSet[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map { elems =>
      val builder = ListSet.newBuilder[A]
      builder ++= elems
      builder.result()
    }
  }

/** Generator of `ListSet[A]`, taken from the [[arbListSet]] instance. */
def listSetGen[A](
  implicit
  A: Arbitrary[A]
): Gen[ListSet[A]] =
  arbListSet(A).arbitrary

/**
 * `Cogen` for `TreeSet[A]`.
 *
 * The elements are copied to a vector and sorted with the implicit
 * `Ordering[A]` before being fed to the `Cogen`.
 */
implicit def cogenTreeSet[A: Cogen: Ordering]: Cogen[TreeSet[A]] =
  Cogen.it { (set: TreeSet[A]) =>
    val ordered = set.toVector.sorted
    ordered.iterator
  }

/**
 * `Arbitrary` instance for `TreeSet[A]`: generates a list of elements and
 * adds them to a `TreeSet` builder ordered by `o`.
 *
 * @param A the `Arbitrary` supplying the set elements
 * @param o the ordering used by the resulting `TreeSet`
 */
implicit def arbTreeSet[A](
  implicit
  A: Arbitrary[A],
  o: Ordering[A]
): Arbitrary[TreeSet[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map { elems =>
      val builder = TreeSet.newBuilder[A]
      builder ++= elems
      builder.result()
    }
  }

/** Generator of `TreeSet[A]`, taken from the [[arbTreeSet]] instance. */
def treeSetGen[A](
  implicit
  A: Arbitrary[A],
  o: Ordering[A]
): Gen[TreeSet[A]] =
  arbTreeSet(A, o).arbitrary

implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary {
for {
int <- Arbitrary.arbitrary[Int]
Expand Down Expand Up @@ -76,7 +117,18 @@ package object frameless {
localDate <- listOfDates
} yield localDate.format(dateTimeFormatter)

val TEST_OUTPUT_DIR = "target/test-output"
// Optional override for the test output directory; `None` means "use the default".
private var outputDir: Option[String] = None

/**
 * Allow usage on non-build environments by overriding the test output directory.
 *
 * Must be called before `TEST_OUTPUT_DIR` is first read: that value is a
 * `lazy val`, so later calls have no effect on it. Passing `null` clears the
 * override and restores the default.
 *
 * @param path directory to use for test output
 */
def setOutputDir(path: String): Unit = {
  // Option(...) maps a null path to None, matching the old null-check behaviour.
  outputDir = Option(path)
}

/**
 * Directory used for test output: the value given to `setOutputDir` if one
 * was set before first access, otherwise `"target/test-output"`.
 */
lazy val TEST_OUTPUT_DIR: String =
  outputDir.getOrElse("target/test-output")

/**
* Will dive down causes until either the cause is true or there are no more causes
Expand Down

0 comments on commit c2f3492

Please sign in to comment.