Skip to content

Commit

Permalink
Add the evil tools as a temporary work-around for the lack of column …
Browse files Browse the repository at this point in the history
…tools being exposed.
  • Loading branch information
holdenk committed Sep 30, 2024
1 parent 507992a commit 6099c4e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
27 changes: 25 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,24 @@ val commonSettings = Seq(

// Allow kafka (and other) utils to have version specific files
val coreSources = unmanagedSourceDirectories in Compile := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
if (sparkVersion.value >= "4.0.0") Seq(
(sourceDirectory in Compile)(_ / "4.0/scala"),
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"), (sourceDirectory in Compile)(_ / "2.0/java")
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "2.4.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
Expand All @@ -186,6 +200,15 @@ val coreSources = unmanagedSourceDirectories in Compile := {

val coreTestSources = unmanagedSourceDirectories in Test := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
(sourceDirectory in Test)(_ / "2.0/scala"),
(sourceDirectory in Test)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "pre-4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.apache.spark.sql.internal

import org.apache.spark.sql._
import org.apache.spark.sql.internal._
import org.apache.spark.sql.catalyst.expressions._

object EvilExpressionColumnNode {
def getExpr(node: ColumnNode): Expression = {
ColumnNodeToExpressionConverter.apply(node)
}
def toColumnNode(expr: Expression): ColumnNode = {
ExpressionColumnNode(expr)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
package com.holdenkarau.spark.testing

import org.apache.spark.sql.{Column => SColumn, SparkSession, DataFrame}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.{Column, SparkSession, DataFrame}
import org.apache.spark.sql.internal._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
Expand Down Expand Up @@ -53,10 +53,11 @@ class SampleSparkExpressionTest extends ScalaDataFrameSuiteBase {
}

object WorkingCodegenExpression {
private def withExpr(expr: Expression): SColumn = new SColumn(expr)
private def withExpr(expr: Expression): Column = new Column(
EvilExpressionColumnNode.toColumnNode(expr))

def work(col: SColumn): SColumn = withExpr {
WorkingCodegenExpression(col.expr)
def work(col: Column): Column = withExpr {
WorkingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
}
}

Expand Down Expand Up @@ -93,10 +94,11 @@ case class WorkingCodegenExpression(child: Expression) extends UnaryExpression {
//end::unary[]

object FailingCodegenExpression {
private def withExpr(expr: Expression): SColumn = new SColumn(expr)
private def withExpr(expr: Expression): Column = new Column(
EvilExpressionColumnNode.toColumnNode(expr))

def fail(col: SColumn): SColumn = withExpr {
FailingCodegenExpression(col.expr)
def fail(col: Column): Column = withExpr {
FailingCodegenExpression(EvilExpressionColumnNode.getExpr(col.node))
}
}

Expand Down

0 comments on commit 6099c4e

Please sign in to comment.