Rdd related helpers #132

Merged: 6 commits, merged on Feb 24, 2022

Changes from 3 commits
5 changes: 5 additions & 0 deletions examples/pom-3.2_2.12.xml
@@ -24,6 +24,11 @@
<artifactId>spark-sql_${scala.compat.version}</artifactId>
<version>${spark3.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.compat.version}</artifactId>
<version>${spark3.version}</version>
</dependency>
</dependencies>

<build>
@@ -0,0 +1,61 @@
/*-
* =LICENSE=
* Kotlin Spark API: Examples for Spark 3.2+ (Scala 2.12)
* ----------
* Copyright (C) 2019 - 2022 JetBrains
* ----------
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =LICENSEEND=
*/
package org.jetbrains.kotlinx.spark.examples

import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.streaming.Durations
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.jetbrains.kotlinx.spark.api.withSpark
import scala.Tuple2
import java.io.Serializable

data class Row @JvmOverloads constructor(
var word: String = "",
) : Serializable

fun main() = withSpark {

val context = JavaStreamingContext(
SparkConf()
.setMaster("local[*]")
.setAppName("Test"),
Durations.seconds(1),
)

val lines = context.socketTextStream("localhost", 9999)

val words = lines.flatMap { it.split(" ").iterator() }

words.foreachRDD { rdd, time ->

// todo convert rdd to dataset using kotlin data class?

val rowRdd = rdd.map { Row(it) }

val dataframe = spark.createDataFrame(rowRdd, Row::class.java)


}


context.start()
context.awaitTermination()
}
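A minimal sketch (not part of the diff) of how the JavaRDDLike.toDS() helper introduced in this PR could answer the todo above: inside the withSpark block the KSparkSession receiver stays in scope, so a mapped JavaRDD converts straight into a typed Dataset. The same conversion should also work inside the foreachRDD block, since its lambda can still see that receiver. The Word data class and the sample values below are illustrative only.

import org.jetbrains.kotlinx.spark.api.withSpark

data class Word(val word: String)

fun main() = withSpark {
    // sc is the JavaSparkContext that KSparkSession exposes in this change
    val javaRdd = sc.parallelize(listOf("foo", "bar", "baz")).map { Word(it) }

    // JavaRDDLike<T, *>.toDS() is one of the helpers added by this PR
    val dataset = javaRdd.toDS()
    dataset.show()
}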
6 changes: 6 additions & 0 deletions kotlin-spark-api/3.2/pom_2.12.xml
@@ -36,6 +36,12 @@
<version>${spark3.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.compat.version}</artifactId>
<version>${spark3.version}</version>
<scope>provided</scope>
</dependency>

<!-- Test dependencies -->
<dependency>
@@ -23,9 +23,10 @@ package org.jetbrains.kotlinx.spark.api

import org.apache.hadoop.shaded.org.apache.commons.math3.exception.util.ArgUtils
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.java.*
import org.apache.spark.api.java.function.*
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.*
import org.apache.spark.sql.Encoders.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -154,6 +155,18 @@ inline fun <reified T> SparkSession.dsOf(vararg t: T): Dataset<T> =
inline fun <reified T> List<T>.toDS(spark: SparkSession): Dataset<T> =
spark.createDataset(this, encoder<T>())

/**
* Utility method to create dataset from RDD
*/
inline fun <reified T> RDD<T>.toDS(spark: SparkSession): Dataset<T> =
spark.createDataset(this, encoder<T>())

/**
* Utility method to create dataset from JavaRDD
*/
inline fun <reified T> JavaRDDLike<T, *>.toDS(spark: SparkSession): Dataset<T> =
spark.createDataset(this.rdd(), encoder<T>())

/**
* Main method of API, which gives you seamless integration with Spark:
* It creates encoder for any given supported type T
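The two new extensions above take the SparkSession explicitly, so they can also be used outside a KSparkSession scope. A hedged sketch follows; the function name is illustrative, and the import path is assumed from the package declaration shown in this file's hunk.

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
import org.jetbrains.kotlinx.spark.api.toDS

fun rddToDatasetExample(spark: SparkSession) {
    val sc = JavaSparkContext(spark.sparkContext)

    // Scala RDD<Int>, converted through the new RDD<T>.toDS(spark) helper
    val ds = sc.parallelize(listOf(1, 2, 3)).rdd().toDS(spark)
    ds.show()
}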
@@ -20,6 +20,11 @@
package org.jetbrains.kotlinx.spark.api

import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaRDDLike
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.SparkSession.Builder
import org.apache.spark.sql.UDFRegistration
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
@@ -78,18 +83,39 @@ inline fun withSpark(builder: Builder, logLevel: SparkLogLevel = ERROR, func: KS
KSparkSession(this).apply {
sparkContext.setLogLevel(logLevel)
func()
sc.stop()
spark.stop()
}
}
.also { it.stop() }
}

/**
* Wrapper for spark creation which copies params from [sparkConf].
*
* @param sparkConf Sets a list of config options based on this.
* @param logLevel Control our logLevel. This overrides any user-defined log settings.
* @param func function which will be executed in context of [KSparkSession] (it means that `this` inside block will point to [KSparkSession])
*/
@JvmOverloads
inline fun withSpark(sparkConf: SparkConf, logLevel: SparkLogLevel = ERROR, func: KSparkSession.() -> Unit) {
withSpark(
builder = SparkSession.builder().config(sparkConf),
logLevel = logLevel,
func = func,
)
}
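// Hedged usage sketch (not part of the diff): calling the SparkConf-based overload above.
// The app name, master, and data are illustrative only.
fun sparkConfOverloadDemo() {
    withSpark(SparkConf().setAppName("conf-demo").setMaster("local[*]")) {
        // dsOf comes from KSparkSession, the receiver of the block
        dsOf(1, 2, 3).show()
    }
}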

/**
* This wrapper over [SparkSession] which provides several additional methods to create [org.apache.spark.sql.Dataset]
*/
@Suppress("EXPERIMENTAL_FEATURE_WARNING", "unused")
inline class KSparkSession(val spark: SparkSession) {
class KSparkSession(val spark: SparkSession) {

val sc: JavaSparkContext = JavaSparkContext(spark.sparkContext)

inline fun <reified T> List<T>.toDS() = toDS(spark)
inline fun <reified T> Array<T>.toDS() = spark.dsOf(*this)
inline fun <reified T> dsOf(vararg arg: T) = spark.dsOf(*arg)
inline fun <reified T> RDD<T>.toDS() = toDS(spark)
Contributor: Should T be Serializable?

Collaborator (author): That doesn't hold for all of them, right? For instance, are primitives Serializable?

Contributor: But there can't be a generic with a primitive inside.

Collaborator (author): Wait, so Int and other primitives that do not implement Serializable are allowed with inline fun <reified T : Serializable> RDD<T>.toDS() = toDS(spark)? That's news to me. But useful!

Contributor: I believe they will be boxed to a serializable Integer.

Collaborator (author): And I'll add it to broadcast too!

Collaborator (author): No, wait, that doesn't hold. You can broadcast a non-Serializable List, for instance. And... you can also make an RDD of a non-Serializable List.

Contributor: List itself is not serializable, but its implementations are.

Collaborator (author): Ah right, but enforcing it to be Serializable requires users to wrap their list, for instance ArrayList(listOf(1, 2, 3)), which is not ideal...

Contributor: No no, if "just" listOf doesn't work, let's not implement it.

inline fun <reified T> JavaRDDLike<T, *>.toDS() = toDS(spark)
val udf: UDFRegistration get() = spark.udf()
}
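A minimal illustration of the Serializable discussion above, using only the Kotlin standard library; the hypothetical requireSerializable stands in for putting a T : Serializable bound on toDS().

import java.io.Serializable

// Hypothetical stand-in for a `T : Serializable` upper bound on toDS()
fun <T : Serializable> requireSerializable(value: T): T = value

fun main() {
    // Does not compile: kotlin.collections.List declares no Serializable supertype,
    // even though the concrete list returned by listOf() is Serializable at runtime.
    // requireSerializable(listOf(1, 2, 3))

    // Compiles, but forces callers to name a concrete implementation:
    requireSerializable(ArrayList(listOf(1, 2, 3)))
}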
@@ -21,6 +21,11 @@ import ch.tutteli.atrium.api.fluent.en_GB.*
import ch.tutteli.atrium.api.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.shouldBe
import org.apache.spark.api.java.JavaDoubleRDD
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.*
import org.apache.spark.sql.streaming.GroupState
@@ -593,6 +598,43 @@ class ApiTest : ShouldSpec({
it.nullable() shouldBe true
}
}
should("Easily convert a (Java)RDD to a Dataset") {
Contributor: I think this huge test should be split into smaller ones?

Collaborator (author): done

// scala RDD
val rdd0: RDD<Int> = sc.parallelize(
listOf(1, 2, 3, 4, 5, 6)
).rdd()
val dataset0: Dataset<Int> = rdd0.toDS()
dataset0.show()

dataset0.toList<Int>() shouldBe listOf(1, 2, 3, 4, 5, 6)

// normal JavaRDD
val rdd1: JavaRDD<Int> = sc.parallelize(
listOf(1, 2, 3, 4, 5, 6)
)
val dataset1: Dataset<Int> = rdd1.toDS()
dataset1.show()

dataset1.toList<Int>() shouldBe listOf(1, 2, 3, 4, 5, 6)

// JavaDoubleRDD
val rdd2: JavaDoubleRDD = sc.parallelizeDoubles(
listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
)
val dataset2: Dataset<Double> = rdd2.toDS()
dataset2.show()

dataset2.toList<Double>() shouldBe listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

// JavaPairRDD
val rdd3: JavaPairRDD<Int, Double> = sc.parallelizePairs(
listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
)
val dataset3: Dataset<Tuple2<Int, Double>> = rdd3.toDS()
dataset3.show()

dataset3.toList<Tuple2<Int, Double>>() shouldBe listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
}
}
}
})