From 6099c4ebf23d3ee0f50154fa6756809897c619e0 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Sun, 29 Sep 2024 19:50:04 -0700
Subject: [PATCH] Add the evil tools as a temporary work-around for the lack of
 column tools being exposed.

---
 build.sbt                                          | 23 +++++++++++++----
 .../apache/spark/sql/internal/evilTools.scala      | 14 ++++++++++
 .../testing/SampleSparkExpressionTest.scala        | 18 +++++++------
 3 files changed, 44 insertions(+), 11 deletions(-)
 create mode 100644 core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala

diff --git a/build.sbt b/build.sbt
index 01bc7cfa..2628d058 100644
--- a/build.sbt
+++ b/build.sbt
@@ -169,10 +169,18 @@ val commonSettings = Seq(
 
 // Allow kafka (and other) utils to have version specific files
 val coreSources = unmanagedSourceDirectories in Compile := {
-  if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
+  if (sparkVersion.value >= "4.0.0") Seq(
+    (sourceDirectory in Compile)(_ / "4.0/scala"),
     (sourceDirectory in Compile)(_ / "2.2/scala"),
     (sourceDirectory in Compile)(_ / "3.0/scala"),
-    (sourceDirectory in Compile)(_ / "2.0/scala"), (sourceDirectory in Compile)(_ / "2.0/java")
+    (sourceDirectory in Compile)(_ / "2.0/scala"),
+    (sourceDirectory in Compile)(_ / "2.0/java")
+  ).join.value
+  else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
+    (sourceDirectory in Compile)(_ / "2.2/scala"),
+    (sourceDirectory in Compile)(_ / "3.0/scala"),
+    (sourceDirectory in Compile)(_ / "2.0/scala"),
+    (sourceDirectory in Compile)(_ / "2.0/java")
   ).join.value
   else if (sparkVersion.value >= "2.4.0" && scalaVersion.value >= "2.12.0") Seq(
     (sourceDirectory in Compile)(_ / "2.2/scala"),
@@ -186,6 +194,15 @@ val coreSources = unmanagedSourceDirectories in Compile := {
 
 val coreTestSources = unmanagedSourceDirectories in Test := {
-  if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
+  if (sparkVersion.value >= "4.0.0") Seq(
+    (sourceDirectory in Test)(_ / "4.0/scala"),
+    (sourceDirectory in Test)(_ / "3.0/scala"),
+    (sourceDirectory in Test)(_ / "3.0/java"),
+    (sourceDirectory in Test)(_ / "2.2/scala"),
+    (sourceDirectory in Test)(_ / "2.0/scala"),
+    (sourceDirectory in Test)(_ / "2.0/java")
+  ).join.value
+  else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
+    (sourceDirectory in Test)(_ / "pre-4.0/scala"),
     (sourceDirectory in Test)(_ / "3.0/scala"),
     (sourceDirectory in Test)(_ / "3.0/java"),
     (sourceDirectory in Test)(_ / "2.2/scala"),
diff --git a/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala b/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala
new file mode 100644
index 00000000..f4a0bb3e
--- /dev/null
+++ b/core/src/main/4.0/scala/org/apache/spark/sql/internal/evilTools.scala
@@ -0,0 +1,14 @@
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.internal._
+import org.apache.spark.sql.catalyst.expressions._
+
+object EvilExpressionColumnNode {
+  def getExpr(node: ColumnNode): Expression = {
+    ColumnNodeToExpressionConverter.apply(node)
+  }
+  def toColumnNode(expr: Expression): ColumnNode = {
+    ExpressionColumnNode(expr)
+  }
+}
diff --git a/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala b/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
index 619d517e..bb8e0805 100644
--- a/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
+++ b/core/src/test/4.0/scala/com/holdenkarau/spark/testing/SampleSparkExpressionTest.scala
@@ -16,8 +16,8 @@
  */
 package com.holdenkarau.spark.testing
 
-import org.apache.spark.sql.{Column => SColumn, SparkSession, DataFrame}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.{Column, SparkSession, DataFrame}
+import org.apache.spark.sql.internal._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -53,10 +53,11 @@ class SampleSparkExpressionTest extends ScalaDataFrameSuiteBase {
 }
 
 object WorkingCodegenExpression {
-  private def withExpr(expr: Expression): SColumn = new SColumn(expr)
+  private def withExpr(expr: Expression): Column = new Column(
+    EvilExpressionColumnNode.toColumnNode(expr))
 
-  def work(col: SColumn): SColumn = withExpr {
-    WorkingCodegenExpression(col.expr)
+  def work(col: Column): Column = withExpr {
+    WorkingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
   }
 }
 
@@ -93,10 +94,11 @@ case class WorkingCodegenExpression(child: Expression) extends UnaryExpression {
 //end::unary[]
 
 object FailingCodegenExpression {
-  private def withExpr(expr: Expression): SColumn = new SColumn(expr)
+  private def withExpr(expr: Expression): Column = new Column(
+    EvilExpressionColumnNode.toColumnNode(expr))
 
-  def fail(col: SColumn): SColumn = withExpr {
-    FailingCodegenExpression(col.expr)
+  def fail(col: Column): Column = withExpr {
+    FailingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
   }
 }
 