Add support for Spark 4 Preview versions (#423)
* Work on supporting Spark 4

* Resolve from maven local for snapshots

* Add 4.0.0-preview{1,2}

* Add Spark 4 preview to actions & readme

* Fix weird merge issue dropping the {

* Don't hardcode ref main

* Make eviction error level global

* Update Java setup

* Use 17 instead of 1.17

* Rename Column to ColumnGenerator (the old ColumnGenerator base class becomes ColumnGeneratorBase) & swap DataframeGenerator to DataFrameGenerator (see the before/after sketch after this list)

* Let's make Dataframe into DataFrame to match (yaaaay, old typos).

* Add explicit sql-api

* Add the evil tools as a temporary work-around for the lack of column tools being exposed.

* Fix version check for core src.

* Only create valid decimals because ANSI errors don't just go null anymore
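For downstream test code the renames are mechanical. A minimal before/after sketch, assuming the usual com.holdenkarau.spark.testing imports and an sqlContext already in scope (the schema and generator values are illustrative):

import com.holdenkarau.spark.testing.{ColumnGenerator, DataFrameGenerator}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.scalacheck.Gen

// Before this commit: new Column("name", ...) passed to DataframeGenerator.
// After: Column is now ColumnGenerator (and the old ColumnGenerator base
// class is ColumnGeneratorBase); DataframeGenerator is DataFrameGenerator.
val schema = StructType(List(StructField("name", StringType)))
val nameGenerator = new ColumnGenerator("name", Gen.const("Holden Hanafy"))
val dataFrameGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(
  sqlContext, schema)(nameGenerator)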
holdenk authored Sep 30, 2024
1 parent 010605d commit 8a3038e
Showing 12 changed files with 287 additions and 64 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/github-actions-basic.yml
@@ -38,9 +38,9 @@ jobs:
- spark: "3.4.0"
java-version: "17"
distribution: "temurin"
# - spark: "4.0.0-PREVIEW2"
# java-version: "17"
# distribution: "temurin"
- spark: "4.0.0-preview2"
java-version: "17"
distribution: "temurin"
env:
SPARK_VERSION: ${{ matrix.spark }}
steps:
@@ -50,11 +50,11 @@
        with:
          fetch-depth: 0
          repository: holdenk/spark-testing-base
      - uses: actions/setup-java@v4
        with:
          java-version: ${{ matrix.java-version }}
          distribution: ${{ matrix.distribution }}
          cache: sbt
      - name: Cache maven modules
        id: cache-maven
        uses: actions/cache@v4.0.0
2 changes: 1 addition & 1 deletion README.md
@@ -97,7 +97,7 @@ While we hope you choose our library, https://github.com/juanrh/sscheck , https:

## [Release Notes](RELEASE_NOTES.md)

JDK17 support exists only for Spark 3.4.0 & Spark 4 (previews)

## Security Disclosure e-mails

68 changes: 56 additions & 12 deletions build.sbt
@@ -41,11 +41,17 @@ lazy val core = (project in file("core"))
"org.apache.spark" %% "spark-sql" % sparkVersion.value,
"org.apache.spark" %% "spark-hive" % sparkVersion.value,
"org.apache.spark" %% "spark-catalyst" % sparkVersion.value,
"org.apache.spark" %% "spark-yarn" % sparkVersion.value,
"org.apache.spark" %% "spark-mllib" % sparkVersion.value
) ++ commonDependencies ++
{
if (sparkVersion.value > "3.0.0") {
if (sparkVersion.value > "4.0.0") {
Seq(
"org.apache.spark" %% "spark-sql-api" % sparkVersion.value,
"io.netty" % "netty-all" % "4.1.96.Final",
"io.netty" % "netty-tcnative-classes" % "2.0.66.Final",
"com.github.luben" % "zstd-jni" % "1.5.5-4"
)
} else if (sparkVersion.value > "3.0.0") {
Seq(
"io.netty" % "netty-all" % "4.1.77.Final",
"io.netty" % "netty-tcnative-classes" % "2.0.52.Final"
Expand Down Expand Up @@ -101,14 +107,24 @@ lazy val kafka_0_8 = {
val commonSettings = Seq(
  organization := "com.holdenkarau",
  publishMavenStyle := true,
  libraryDependencySchemes += "com.github.luben" %% "zstd-jni" % "early-semver",
  evictionErrorLevel := Level.Info,
  sparkVersion := System.getProperty("sparkVersion", "2.4.8"),
  sparkTestingVersion := "1.6.0",
  version := sparkVersion.value + "_" + sparkTestingVersion.value,
  scalaVersion := {
    if (sparkVersion.value >= "4.0.0") {
      "2.13.13"
    } else {
      "2.12.15"
    }
  },
  crossScalaVersions := {
    if (sparkVersion.value >= "4.0.0") {
      Seq("2.13.13")
    } else if (sparkVersion.value >= "3.5.0") {
      Seq("2.12.15", "2.13.13")
    } else if (sparkVersion.value >= "3.2.0") {
      Seq("2.12.15", "2.13.10")
    } else if (sparkVersion.value >= "3.0.0") {
      Seq("2.12.15")
Expand All @@ -118,9 +134,13 @@ val commonSettings = Seq(
},
scalacOptions ++= Seq("-deprecation", "-unchecked", "-Yrangepos"),
  javacOptions ++= {
    if (sparkVersion.value >= "4.0.0") {
      Seq("-source", "17", "-target", "17")
    } else {
      Seq("-source", "1.8", "-target", "1.8")
    }
  },
javaOptions ++= Seq("-Xms5G", "-Xmx5G"),
javaOptions ++= Seq("-Xms8G", "-Xmx8G"),

coverageHighlighting := true,

@@ -142,16 +162,31 @@ val commonSettings = Seq(
"Typesafe repository" at "https://repo.typesafe.com/typesafe/releases/",
"Second Typesafe repo" at "https://repo.typesafe.com/typesafe/maven-releases/",
"Mesosphere Public Repository" at "https://downloads.mesosphere.io/maven",
Resolver.sonatypeRepo("public")
Resolver.sonatypeRepo("public"),
Resolver.mavenLocal
)
)

// Allow kafka (and other) utils to have version specific files
val coreSources = unmanagedSourceDirectories in Compile := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
if (sparkVersion.value >= "4.0.0") Seq(
(sourceDirectory in Compile)(_ / "4.0/scala"),
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"), (sourceDirectory in Compile)(_ / "2.0/java")
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
(sourceDirectory in Compile)(_ / "3.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/scala"),
(sourceDirectory in Compile)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "2.4.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Compile)(_ / "2.2/scala"),
@@ -164,7 +199,16 @@ val coreSources = unmanagedSourceDirectories in Compile := {
}

val coreTestSources = unmanagedSourceDirectories in Test := {
if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
if (sparkVersion.value >= "4.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
(sourceDirectory in Test)(_ / "2.0/scala"),
(sourceDirectory in Test)(_ / "2.0/java")
).join.value
else if (sparkVersion.value >= "3.0.0" && scalaVersion.value >= "2.12.0") Seq(
(sourceDirectory in Test)(_ / "pre-4.0/scala"),
(sourceDirectory in Test)(_ / "3.0/scala"),
(sourceDirectory in Test)(_ / "3.0/java"),
(sourceDirectory in Test)(_ / "2.2/scala"),
@@ -243,6 +287,6 @@ lazy val publishSettings = Seq(
}
)

lazy val noPublishSettings = {
skip in publish := true
}
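Given commonSettings above, published artifacts are versioned as sparkVersion + "_" + sparkTestingVersion, and the repo itself builds against a preview by passing -DsparkVersion=4.0.0-preview2 to sbt. A sketch of a downstream build.sbt inferred from those settings (coordinates assumed, not copied from the project docs):

// Hypothetical consumer build.sbt. Spark 4 previews are Scala 2.13 and
// JDK 17 only, matching the gated scalaVersion/javacOptions above.
scalaVersion := "2.13.13"
libraryDependencies +=
  "com.holdenkarau" %% "spark-testing-base" % "4.0.0-preview2_1.6.0" % "test"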
@@ -7,7 +7,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.scalacheck.{Arbitrary, Gen}

object DataFrameGenerator {

/**
* Creates a DataFrame Generator for the given Schema.
@@ -48,13 +48,16 @@ object DataframeGenerator {
*/
def arbitraryDataFrameWithCustomFields(
sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1)
(userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = {
import sqlContext._

val arbitraryRDDs = RDDGenerator.genRDD(
sqlContext.sparkContext, minPartitions)(
getRowGenerator(schema, userGenerators))
Arbitrary {
arbitraryRDDs.map { r =>
sqlContext.createDataFrame(r, schema)
}
}
}

@@ -80,7 +83,7 @@
* @return Gen[Row]
*/
def getRowGenerator(
schema: StructType, customGenerators: Seq[ColumnGeneratorBase]): Gen[Row] = {
val generators: List[Gen[Any]] =
createGenerators(schema.fields, customGenerators)
val listGen: Gen[List[Any]] =
@@ -92,14 +95,14 @@

private def createGenerators(
fields: Array[StructField],
userGenerators: Seq[ColumnGeneratorBase]):
List[Gen[Any]] = {
val generatorMap = userGenerators.map(
generator => (generator.columnName -> generator)).toMap
fields.toList.map { field =>
if (generatorMap.contains(field.name)) {
generatorMap.get(field.name) match {
case Some(gen: Column) => gen.gen
case Some(gen: ColumnGenerator) => gen.gen
case Some(list: ColumnList) => getGenerator(field.dataType, list.gen, nullable = field.nullable)
}
}
@@ -109,7 +112,7 @@

private def getGenerator(
dataType: DataType,
generators: Seq[ColumnGeneratorBase] = Seq(),
nullable: Boolean = false): Gen[Any] = {
val nonNullGen = dataType match {
case ByteType => Arbitrary.arbitrary[Byte]
@@ -128,9 +131,21 @@
l => new Date(l/10000)
}
case dec: DecimalType => {
// With the new ANSI default we need to make sure we're passing in
// valid values.
Arbitrary.arbitrary[BigDecimal]
.retryUntil { d =>
try {
val sd = new Decimal()
// Make sure it can be converted
sd.set(d, dec.precision, dec.scale)
true
} catch {
case e: Exception => false
}
}
.map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP))
.asInstanceOf[Gen[java.math.BigDecimal]]
}
case arr: ArrayType => {
val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
@@ -165,27 +180,27 @@
}

/**
 * Previously Column. Allows the user to specify a generator for a
 * specific column.
 */
class ColumnGenerator(val columnName: String, generator: => Gen[Any])
  extends ColumnGeneratorBase {
  lazy val gen = generator
}

/**
 * ColumnList allows users to specify custom generators for a list of
 * columns inside a StructType column.
 */
class ColumnList(val columnName: String, generators: => Seq[ColumnGeneratorBase])
  extends ColumnGeneratorBase {
  lazy val gen = generators
}

/**
 * ColumnGeneratorBase - previously ColumnGenerator; it is now the base class
 * for all column generators.
 */
abstract class ColumnGeneratorBase extends java.io.Serializable {
  val columnName: String
}
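The decimal change above matters under Spark 4's ANSI default: out-of-range decimals now raise errors instead of silently going null, so the generator round-trips each candidate through Decimal.set and retries on failure. A small property-test sketch under those assumptions (the schema is illustrative):

import com.holdenkarau.spark.testing.DataFrameGenerator
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}
import org.scalacheck.Prop.forAll

// Every generated value fits DecimalType(5, 2), so materializing the
// DataFrame under ANSI mode should not throw.
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
val schema = StructType(List(StructField("price", DecimalType(5, 2))))
val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
val property = forAll(dataFrameGen.arbitrary) { df =>
  df.schema == schema && df.count() >= 0
}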
@@ -8,15 +8,15 @@ trait Prettify {
val maxNumberOfShownValues = 100

implicit def prettyDataFrame(dataframe: DataFrame): Pretty =
Pretty { _ => describeDataframe(dataframe)}
Pretty { _ => describeDataFrame(dataframe)}

implicit def prettyRDD(rdd: RDD[_]): Pretty =
Pretty { _ => describeRDD(rdd)}

implicit def prettyDataset(dataset: Dataset[_]): Pretty =
Pretty { _ => describeDataset(dataset)}

private def describeDataFrame(dataframe: DataFrame) =
s"""<DataFrame: schema = ${dataframe.toString}, size = ${dataframe.count()},
|values = (${dataframe.take(maxNumberOfShownValues).mkString(", ")})>""".
stripMargin.replace("\n", " ")
@@ -0,0 +1,14 @@
package org.apache.spark.sql.internal

import org.apache.spark.sql._
import org.apache.spark.sql.internal._
import org.apache.spark.sql.catalyst.expressions._

object EvilExpressionColumnNode {
def getExpr(node: ColumnNode): Expression = {
ColumnNodeToExpressionConverter.apply(node)
}
def toColumnNode(expr: Expression): ColumnNode = {
ExpressionColumnNode(expr)
}
}
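A sketch of how this shim gets used: in Spark 4 a Column wraps a ColumnNode rather than exposing its Catalyst Expression, so test utilities that still need the Expression can route through this object. This assumes Spark 4's Column exposes its node field; the helper below is illustrative, not part of the commit:

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.EvilExpressionColumnNode

// Hypothetical helper: recover the Catalyst Expression behind a Column
// (assumes Spark 4's Column.node accessor).
def expressionOf(col: Column): Expression =
  EvilExpressionColumnNode.getExpr(col.node)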
@@ -15,12 +15,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
test("vector generation") {
val schema = StructType(List(StructField("vector", VectorType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
    val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)

    val property =
      forAll(dataFrameGen.arbitrary) {
        dataFrame => {
          dataFrame.schema === schema && dataFrame.count >= 0
}
}

@@ -30,12 +30,12 @@ class MLScalaCheckTest extends AnyFunSuite with SharedSparkContext with Checkers
test("matrix generation") {
val schema = StructType(List(StructField("matrix", MatrixType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
    val dataFrameGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)

    val property =
      forAll(dataFrameGen.arbitrary) {
        dataFrame => {
          dataFrame.schema === schema && dataFrame.count >= 0
}
}

@@ -16,10 +16,10 @@ class PrettifyTest extends AnyFunSuite with SharedSparkContext with Checkers wit
test("pretty output of DataFrame's check") {
val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
val sqlContext = SparkSession.builder.getOrCreate().sqlContext
    val nameGenerator = new ColumnGenerator("name", Gen.const("Holden Hanafy"))
    val ageGenerator = new ColumnGenerator("age", Gen.const(20))

    val dataframeGen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)

val actual = runFailingCheck(dataframeGen.arbitrary)
val expected =
