Added new JdbcIO read/write params to Scio #4820

Merged · 20 commits · Jun 6, 2023
Changes from 3 commits
@@ -1095,8 +1095,7 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
* a new SCollection of (key, top `num` values) pairs
* @group per_key
*/
def topByKey(num: Int)(implicit ord: Ordering[V]): SCollection[(K, Iterable[V])] =
this.applyPerKey(Top.perKey[K, V, Ordering[V]](num, ord))(kvListToTuple)
def topByKey(num: Int)(implicit ord: Ordering[V]): SCollection[(K, Iterable[V])] = this.applyPerKey(Top.perKey[K, V, Ordering[V]](num, ord))(kvListToTuple)

/**
* Return an SCollection with the values of each tuple.
49 changes: 24 additions & 25 deletions scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala
@@ -17,14 +17,11 @@

package com.spotify.scio.jdbc

import com.spotify.scio.values.SCollection
import com.spotify.scio.ScioContext
import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TestIO}
import org.apache.beam.sdk.io.{jdbc => beam}
import java.sql.{PreparedStatement, ResultSet}

import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.io.TapT
import com.spotify.scio.io._
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.io.{jdbc => beam}

sealed trait JdbcIO[T] extends ScioIO[T]

@@ -36,9 +33,10 @@ object JdbcIO {
}

private[jdbc] def jdbcIoId(opts: JdbcIoOptions): String = opts match {
case JdbcReadOptions(connOpts, query, _, _, _, _) => jdbcIoId(connOpts, query)
case JdbcWriteOptions(connOpts, statement, _, _, _, _) =>
jdbcIoId(connOpts, statement)
case readOpts: JdbcReadOptions[_] =>
jdbcIoId(readOpts.connectionOptions, readOpts.query)
case writeOpts: JdbcWriteOptions[_] =>
jdbcIoId(writeOpts.connectionOptions, writeOpts.statement)
Contributor: nicer like this!
}

private[jdbc] def jdbcIoId(opts: JdbcConnectionOptions, query: String): String = {
@@ -66,7 +64,7 @@
final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends JdbcIO[T] {
Contributor: IMHO we've not defined the constructor for JdbcIOs properly. Here we should only have the data that allows us to identify the target destination (basically, what is required to distinguish the mocked IO), so the connection options and the query. All the other params should be passed as ReadP or WriteP.

Contributor (author): Yes, I also noted that it stands out from other IOs. Will change it.

Contributor: Let's do it in another PR. I think we have to check some other IOs too (I recall CsvIO has the same issue).
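For concreteness, a rough sketch of the split the reviewer is describing. This is hypothetical and not part of this PR; the names (JdbcSelectId, JdbcSelectReadParam) are illustrative only.

import java.sql.{PreparedStatement, ResultSet}

// Hypothetical sketch of the suggested refactor: the IO constructor keeps
// only what identifies the source; every tuning knob becomes a ReadP field.
final case class JdbcSelectId(
  connectionOptions: JdbcConnectionOptions, // identifies the database
  query: String                             // identifies the read
)

final case class JdbcSelectReadParam[T](
  rowMapper: ResultSet => T,
  statementPreparator: PreparedStatement => Unit = null,
  fetchSize: Int = JdbcIoOptions.BeamDefaultFetchSize,
  outputParallelization: Boolean = JdbcIoOptions.DefaultOutputParallelization
)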

override type ReadP = Unit
override type WriteP = Nothing
final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]

override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(readOptions)})"

@@ -77,24 +75,21 @@ final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends J
.withCoder(coder)
.withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(readOptions.connectionOptions))
.withQuery(readOptions.query)
.withRowMapper(new beam.JdbcIO.RowMapper[T] {
override def mapRow(resultSet: ResultSet): T =
readOptions.rowMapper(resultSet)
})
.withRowMapper(readOptions.rowMapper(_))
.withOutputParallelization(readOptions.outputParallelization)

if (readOptions.dataSourceProviderFn != null) {
transform = transform.withDataSourceProviderFn((_: Void) => readOptions.dataSourceProviderFn())
}
if (readOptions.statementPreparator != null) {
transform = transform
.withStatementPreparator(new beam.JdbcIO.StatementPreparator {
override def setParameters(preparedStatement: PreparedStatement): Unit =
readOptions.statementPreparator(preparedStatement)
})
.withStatementPreparator(readOptions.statementPreparator(_))
}
if (readOptions.fetchSize != JdbcIoOptions.BeamDefaultFetchSize) {
// override default fetch size.
transform = transform.withFetchSize(readOptions.fetchSize)
}
sc.applyTransform(transform)
sc.applyTransform(readOptions.configOverride(transform))
}

override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] =
@@ -107,7 +102,7 @@ final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends J
final case class JdbcWrite[T](writeOptions: JdbcWriteOptions[T]) extends JdbcIO[T] {
override type ReadP = Nothing
override type WriteP = Unit
final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T]

override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(writeOptions)})"

@@ -119,23 +114,27 @@ final case class JdbcWrite[T](writeOptions: JdbcWriteOptions[T]) extends JdbcIO[
.write[T]()
.withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(writeOptions.connectionOptions))
.withStatement(writeOptions.statement)

if (writeOptions.dataSourceProviderFn != null) {
transform = transform.withDataSourceProviderFn((_: Void) => writeOptions.dataSourceProviderFn())
}
if (writeOptions.preparedStatementSetter != null) {
transform = transform
.withPreparedStatementSetter(new beam.JdbcIO.PreparedStatementSetter[T] {
override def setParameters(element: T, preparedStatement: PreparedStatement): Unit =
writeOptions.preparedStatementSetter(element, preparedStatement)
})
.withPreparedStatementSetter(writeOptions.preparedStatementSetter(_, _))
}
if (writeOptions.batchSize != JdbcIoOptions.BeamDefaultBatchSize) {
// override default batch size.
transform = transform.withBatchSize(writeOptions.batchSize)
}
if (writeOptions.autoSharding) {
transform = transform.withAutoSharding()
}

transform = transform
.withRetryConfiguration(writeOptions.retryConfiguration)
.withRetryStrategy(writeOptions.retryStrategy.apply)

data.applyInternal(transform)
data.applyInternal(writeOptions.configOverride(transform))
EmptyTap
}

23 changes: 20 additions & 3 deletions scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcOptions.scala
@@ -17,9 +17,10 @@

package com.spotify.scio.jdbc

import org.apache.beam.sdk.io.jdbc.JdbcIO.{DefaultRetryStrategy, RetryConfiguration}
import org.apache.beam.sdk.io.jdbc.JdbcIO.{DefaultRetryStrategy, Read, RetryConfiguration, Write}

import java.sql.{Driver, PreparedStatement, ResultSet, SQLException}
import javax.sql.DataSource

/**
* Options for a JDBC connection.
@@ -71,19 +72,28 @@ sealed trait JdbcIoOptions
* use apache beam default fetch size if the value is -1
* @param outputParallelization
* reshuffle result to distribute it to all workers. Default to true.
* @param dataSourceProviderFn
* function to provide a custom [[javax.sql.DataSource]]
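* @param configOverride
* function applied to the built Beam [[org.apache.beam.sdk.io.jdbc.JdbcIO.Read]] transform,
* allowing any remaining Beam setting to be overridden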
*/
final case class JdbcReadOptions[T](
connectionOptions: JdbcConnectionOptions,
query: String,
statementPreparator: PreparedStatement => Unit = null,
rowMapper: ResultSet => T,
fetchSize: Int = JdbcIoOptions.BeamDefaultFetchSize,
outputParallelization: Boolean = JdbcIoOptions.DefaultOutputParallelization
outputParallelization: Boolean = JdbcIoOptions.DefaultOutputParallelization,
dataSourceProviderFn: () => DataSource = null,
configOverride: Read[T] => Read[T] = identity[Read[T]] _
) extends JdbcIoOptions
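To make the new read parameters concrete, a minimal usage sketch. Everything here is hypothetical: the connection values, query, and pooledDataSource helper are placeholders, and sc is assumed to be a ScioContext.

import java.sql.ResultSet
import javax.sql.DataSource

def pooledDataSource: DataSource = ??? // assumed: some connection-pooled DataSource

val readOpts = JdbcReadOptions[String](
  connectionOptions = JdbcConnectionOptions(
    username = "user",
    password = Some("secret"),
    connectionUrl = "jdbc:mysql://localhost/mydb",
    driverClass = classOf[java.sql.Driver]
  ),
  query = "SELECT name FROM users",
  rowMapper = (rs: ResultSet) => rs.getString(1),
  dataSourceProviderFn = () => pooledDataSource, // new: custom DataSource
  // new: escape hatch to any Beam Read setting (fetch size shown only as an
  // example; it is also exposed directly on JdbcReadOptions)
  configOverride = _.withFetchSize(1000)
)
val names = sc.jdbcSelect(readOpts) // SCollection[String]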

/**
* Options for writing to a JDBC source.
*
* NB: in case of transient failures, Beam runners may execute parts of the write multiple times
* for fault tolerance. Because of that, you should avoid using INSERT statements, since they risk
* duplicating records in the database or failing due to primary-key conflicts. Consider using
* MERGE ("upsert") statements supported by your database instead.
*
* @param connectionOptions
* connection options
* @param statement
Expand All @@ -96,12 +106,19 @@ final case class JdbcReadOptions[T](
* [[org.apache.beam.sdk.io.jdbc.JdbcIO.RetryConfiguration]] for specifying retry behavior
* @param retryStrategy
* A predicate of [[java.sql.SQLException]] indicating a failure to retry
* @param autoSharding
* If true, enables using a dynamically determined number of shards to write.
* @param dataSourceProviderFn
* function to provide a custom [[javax.sql.DataSource]]
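* @param configOverride
* function applied to the built Beam [[org.apache.beam.sdk.io.jdbc.JdbcIO.Write]] transform,
* allowing any remaining Beam setting to be overridden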
*/
final case class JdbcWriteOptions[T](
connectionOptions: JdbcConnectionOptions,
statement: String,
preparedStatementSetter: (T, PreparedStatement) => Unit = null,
batchSize: Long = JdbcIoOptions.BeamDefaultBatchSize,
retryConfiguration: RetryConfiguration = JdbcIoOptions.BeamDefaultRetryConfiguration,
retryStrategy: SQLException => Boolean = new DefaultRetryStrategy().apply
retryStrategy: SQLException => Boolean = new DefaultRetryStrategy().apply,
autoSharding: Boolean = false,
dataSourceProviderFn: () => DataSource = null,
configOverride: Write[T] => Write[T] = identity[Write[T]] _
) extends JdbcIoOptions
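A matching write sketch, also hypothetical: connOpts stands for the connection options from the read example, data for an SCollection[(String, Int)], and the MERGE syntax is database-specific.

import java.sql.PreparedStatement

val writeOpts = JdbcWriteOptions[(String, Int)](
  connectionOptions = connOpts,
  statement = "MERGE INTO scores ...", // upsert-style, per the note above
  preparedStatementSetter = (kv: (String, Int), stmt: PreparedStatement) => {
    stmt.setString(1, kv._1)
    stmt.setInt(2, kv._2)
  },
  autoSharding = true, // new: runner-determined shard count
  // new: escape hatch to any Beam Write setting (batch size shown only as an
  // example; it is also exposed directly on JdbcWriteOptions)
  configOverride = _.withBatchSize(512L)
)
data.saveAsJdbc(writeOpts)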
72 changes: 72 additions & 0 deletions scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcIOTests.scala
@@ -0,0 +1,72 @@
package com.spotify.scio.jdbc

import com.spotify.scio.ScioContext
import org.apache.beam.sdk.Pipeline.PipelineVisitor
import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior
import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO}
import org.apache.beam.sdk.runners.TransformHierarchy
import org.apache.beam.sdk.transforms.PTransform
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should._

import java.sql.ResultSet
import scala.collection.mutable.ArrayBuffer

class JdbcIOTests extends AnyFlatSpec with Matchers {

it must "add to pipeline overridden Read transform" in {
val args = Array[String]()
val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args)
val sc = ScioContext(opts)
var expectedTransform: BJdbcIO.Read[String] = null
sc.jdbcSelect[String](
getDefaultReadOptions(opts).copy(configOverride = r => {
expectedTransform = r.withQuery("overridden query")
Contributor: can we get the query back instead of memorizing the transform in a var?

Contributor (author): Yeah, we can. This was the simplest code; otherwise we would need to match the transform by type, which is not difficult either :)
expectedTransform
})
)

expectedTransform should not be null
getPipelineTransforms(sc) should contain(expectedTransform)
}
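Following up on the thread above, one var-free shape for the end of the test. A sketch only: it assumes the Beam Read reports its query via display data.

// Inside the test, instead of capturing the transform in a var:
import org.apache.beam.sdk.transforms.display.DisplayData

val reads = getPipelineTransforms(sc).collect { case r: BJdbcIO.Read[_] => r }
reads should not be empty
DisplayData.from(reads.head).toString should include("overridden query")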

private def getPipelineTransforms(sc: ScioContext): Iterable[PTransform[_, _]] = {
val actualTransforms = new ArrayBuffer[PTransform[_, _]]()
sc.pipeline.traverseTopologically(new PipelineVisitor.Defaults {
override def enterCompositeTransform(
node: TransformHierarchy#Node
): PipelineVisitor.CompositeBehavior = {
actualTransforms.addOne(node.getTransform)
CompositeBehavior.ENTER_TRANSFORM
}
})
actualTransforms
}

def getDefaultReadOptions(opts: CloudSqlOptions): JdbcReadOptions[String] =
JdbcReadOptions(
connectionOptions = getConnectionOptions(opts),
query = "SELECT <this> FROM <this>",
rowMapper = (rs: ResultSet) => rs.getString(1)
)

def getWriteOptions(opts: CloudSqlOptions): JdbcWriteOptions[String] =
JdbcWriteOptions[String](
connectionOptions = getConnectionOptions(opts),
statement = "INSERT INTO <this> VALUES( ?, ? ..?)"
)
Contributor: Looks like this is unused so far.

def getConnectionOptions(opts: CloudSqlOptions): JdbcConnectionOptions =
JdbcConnectionOptions(
username = opts.getCloudSqlUsername,
password = Some(opts.getCloudSqlPassword),
connectionUrl = connectionUrl(opts),
driverClass = classOf[java.sql.Driver]
)

def connectionUrl(opts: CloudSqlOptions): String =
s"jdbc:mysql://google/${opts.getCloudSqlDb}?" +
s"cloudSqlInstance=${opts.getCloudSqlInstanceConnectionName}&" +
s"socketFactory=com.google.cloud.sql.mysql.SocketFactory"

}