diff --git a/scio-examples/src/main/scala/com/spotify/scio/examples/extra/CloudSqlExample.scala b/scio-examples/src/main/scala/com/spotify/scio/examples/extra/CloudSqlExample.scala index 4c960edd3c..7b1692d388 100644 --- a/scio-examples/src/main/scala/com/spotify/scio/examples/extra/CloudSqlExample.scala +++ b/scio-examples/src/main/scala/com/spotify/scio/examples/extra/CloudSqlExample.scala @@ -29,14 +29,24 @@ object CloudSqlExample { val sc = ScioContext(opts) val connOptions = getConnectionOptions(opts) - val readOptions = getReadOptions(connOptions) - val writeOptions = getWriteOptions(connOptions) // Read from Cloud SQL - sc.jdbcSelect(readOptions) - .map(kv => (kv._1.toUpperCase, kv._2)) + sc.jdbcSelect( + connOptions, + // Read from a table called `word_count` with two columns `word` and `count` + "SELECT * FROM word_count" + ) { r => + r.getString(1) -> r.getLong(2) + }.map { case (word, count) => word.toUpperCase -> count } // Write to Cloud SQL - .saveAsJdbc(writeOptions) + .saveAsJdbc( + connOptions, + // Write to a table called `result_word_count` with two columns `word` and `count` + "INSERT INTO result_word_count values(?, ?)" + ) { case ((word, count), s) => + s.setString(1, word) + s.setLong(2, count) + } sc.run() () } @@ -64,23 +74,4 @@ object CloudSqlExample { driverClass = classOf[com.mysql.jdbc.Driver], connectionUrl = getJdbcUrl(opts) ) - - // Read from a table called `word_count` with two columns `word` and `count` - def getReadOptions(connOpts: JdbcConnectionOptions): JdbcReadOptions[(String, Long)] = - JdbcReadOptions( - connectionOptions = connOpts, - query = "SELECT * FROM word_count", - rowMapper = r => (r.getString(1), r.getLong(2)) - ) - - // Write to a table called `result_word_count` with two columns `word` and `count` - def getWriteOptions(connOpts: JdbcConnectionOptions): JdbcWriteOptions[(String, Long)] = - JdbcWriteOptions( - connectionOptions = connOpts, - statement = "INSERT INTO result_word_count values(?, ?)", - preparedStatementSetter = (kv, s) => { - s.setString(1, kv._1) - s.setLong(2, kv._2) - } - ) } diff --git a/scio-examples/src/test/scala/com/spotify/scio/examples/extra/CloudSqlExampleTest.scala b/scio-examples/src/test/scala/com/spotify/scio/examples/extra/CloudSqlExampleTest.scala index 609023199a..bc05032a97 100644 --- a/scio-examples/src/test/scala/com/spotify/scio/examples/extra/CloudSqlExampleTest.scala +++ b/scio-examples/src/test/scala/com/spotify/scio/examples/extra/CloudSqlExampleTest.scala @@ -32,16 +32,18 @@ class CloudSqlExampleTest extends PipelineSpec { ) val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args) val connOpts = CloudSqlExample.getConnectionOptions(opts) - val readOpts = CloudSqlExample.getReadOptions(connOpts) - val writeOpts = CloudSqlExample.getWriteOptions(connOpts) + val query = "SELECT * FROM word_count" + val statement = "INSERT INTO result_word_count values(?, ?)" val input = Seq("a" -> 1L, "b" -> 2L, "c" -> 3L) val expected = input.map(kv => (kv._1.toUpperCase, kv._2)) JobTest[com.spotify.scio.examples.extra.CloudSqlExample.type] .args(args: _*) - .input(JdbcIO(readOpts), input) - .output(JdbcIO[(String, Long)](writeOpts))(coll => coll should containInAnyOrder(expected)) + .input(JdbcIO(connOpts, query), input) + .output(JdbcIO[(String, Long)](connOpts, statement))(coll => + coll should containInAnyOrder(expected) + ) .run() } } diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/DatastoreIO.scala 
b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/DatastoreIO.scala index 9cb744bf30..6eee9a1b9a 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/DatastoreIO.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/DatastoreIO.scala @@ -22,28 +22,32 @@ import com.spotify.scio.values.SCollection import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TapT} import com.google.datastore.v1.{Entity, Query} import com.spotify.scio.coders.{Coder, CoderMaterializer} -import org.apache.beam.sdk.io.gcp.{datastore => beam} +import org.apache.beam.sdk.io.gcp.datastore.{DatastoreIO => BDatastoreIO, DatastoreV1 => BDatastore} final case class DatastoreIO(projectId: String) extends ScioIO[Entity] { override type ReadP = DatastoreIO.ReadParam - override type WriteP = Unit + override type WriteP = DatastoreIO.WriteParam override val tapT: TapT.Aux[Entity, Nothing] = EmptyTapOf[Entity] override protected def read(sc: ScioContext, params: ReadP): SCollection[Entity] = { val coder = CoderMaterializer.beam(sc, Coder.protoMessageCoder[Entity]) + val read = BDatastoreIO + .v1() + .read() + .withProjectId(projectId) + .withNamespace(params.namespace) + .withQuery(params.query) sc.applyTransform( - beam.DatastoreIO - .v1() - .read() - .withProjectId(projectId) - .withNamespace(params.namespace) - .withQuery(params.query) + Option(params.configOverride).map(_(read)).getOrElse(read) ).setCoder(coder) } override protected def write(data: SCollection[Entity], params: WriteP): Tap[Nothing] = { - data.applyInternal(beam.DatastoreIO.v1.write.withProjectId(projectId)) + val write = BDatastoreIO.v1.write.withProjectId(projectId) + data.applyInternal( + Option(params.configOverride).map(_(write)).getOrElse(write) + ) EmptyTap } @@ -51,5 +55,13 @@ final case class DatastoreIO(projectId: String) extends ScioIO[Entity] { } object DatastoreIO { - final case class ReadParam(query: Query, namespace: String = null) + final case class ReadParam( + query: Query, + namespace: String = null, + configOverride: BDatastore.Read => BDatastore.Read = identity + ) + + final case class WriteParam( + configOverride: BDatastore.Write => BDatastore.Write = identity + ) } diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/SCollectionSyntax.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/SCollectionSyntax.scala index dd692f9f38..dc0af533fb 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/SCollectionSyntax.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/SCollectionSyntax.scala @@ -22,6 +22,8 @@ import com.spotify.scio.values.SCollection import com.spotify.scio.datastore.DatastoreIO import com.spotify.scio.io.ClosedTap import com.google.datastore.v1.Entity +import com.spotify.scio.datastore.DatastoreIO.WriteParam +import org.apache.beam.sdk.io.gcp.datastore.{DatastoreV1 => BDatastore} final class SCollectionEntityOps[T <: Entity](private val coll: SCollection[T]) extends AnyVal { @@ -29,8 +31,11 @@ final class SCollectionEntityOps[T <: Entity](private val coll: SCollection[T]) * Save this SCollection as a Datastore dataset. Note that elements must be of type `Entity`. 
* @group output */ - def saveAsDatastore(projectId: String): ClosedTap[Nothing] = - coll.covary_[Entity].write(DatastoreIO(projectId)) + def saveAsDatastore( + projectId: String, + configOverride: BDatastore.Write => BDatastore.Write = identity + ): ClosedTap[Nothing] = + coll.covary_[Entity].write(DatastoreIO(projectId))(WriteParam(configOverride)) } trait SCollectionSyntax { diff --git a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/ScioContextSyntax.scala b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/ScioContextSyntax.scala index 54ad8e7800..f6399941fb 100644 --- a/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/ScioContextSyntax.scala +++ b/scio-google-cloud-platform/src/main/scala/com/spotify/scio/datastore/syntax/ScioContextSyntax.scala @@ -21,6 +21,8 @@ import com.spotify.scio.ScioContext import com.spotify.scio.values.SCollection import com.spotify.scio.datastore.DatastoreIO import com.google.datastore.v1.{Entity, Query} +import com.spotify.scio.datastore.DatastoreIO.ReadParam +import org.apache.beam.sdk.io.gcp.datastore.{DatastoreV1 => BDatastore} final class ScioContextOps(private val sc: ScioContext) extends AnyVal { @@ -28,8 +30,13 @@ final class ScioContextOps(private val sc: ScioContext) extends AnyVal { /** * Get an SCollection for a Datastore query. * @group input */ - def datastore(projectId: String, query: Query, namespace: String = null): SCollection[Entity] = - sc.read(DatastoreIO(projectId))(DatastoreIO.ReadParam(query, namespace)) + def datastore( + projectId: String, + query: Query, + namespace: String = null, + configOverride: BDatastore.Read => BDatastore.Read = identity + ): SCollection[Entity] = + sc.read(DatastoreIO(projectId))(ReadParam(query, namespace, configOverride)) } trait ScioContextSyntax { diff --git a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcConnectionOptions.scala b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcConnectionOptions.scala new file mode 100644 index 0000000000..d10dc4c862 --- /dev/null +++ b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcConnectionOptions.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.jdbc + +import java.sql.Driver + +/** + * Options for a JDBC connection. + * + * @param username + * database login username + * @param password + * database login password, if one exists + * @param connectionUrl + * connection URL, e.g. "jdbc:mysql://[host]:[port]/db?"
+ * @param driverClass + * subclass of [[java.sql.Driver]] + */ +final case class JdbcConnectionOptions( + username: String, + password: Option[String], + connectionUrl: String, + driverClass: Class[_ <: Driver] +) diff --git a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala index 3e2cb09f41..d1079de75c 100644 --- a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala +++ b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala @@ -17,30 +17,32 @@ package com.spotify.scio.jdbc -import com.spotify.scio.values.SCollection import com.spotify.scio.ScioContext -import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TestIO} -import org.apache.beam.sdk.io.{jdbc => beam} -import java.sql.{PreparedStatement, ResultSet} - import com.spotify.scio.coders.{Coder, CoderMaterializer} -import com.spotify.scio.io.TapT +import com.spotify.scio.io._ +import com.spotify.scio.values.SCollection +import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO} + +import java.sql.{PreparedStatement, ResultSet, SQLException} +import javax.sql.DataSource sealed trait JdbcIO[T] extends ScioIO[T] object JdbcIO { + + @deprecated("Use new API overloads with multiple parameters", since = "0.13.0") final def apply[T](opts: JdbcIoOptions): JdbcIO[T] = + opts match { + case readOpts: JdbcReadOptions[_] => apply(readOpts.connectionOptions, readOpts.query) + case writeOpts: JdbcWriteOptions[_] => apply(writeOpts.connectionOptions, writeOpts.statement) + } + + final def apply[T](opts: JdbcConnectionOptions, query: String): JdbcIO[T] = new JdbcIO[T] with TestIO[T] { final override val tapT = EmptyTapOf[T] - override def testId: String = s"JdbcIO(${jdbcIoId(opts)})" + override def testId: String = s"JdbcIO(${jdbcIoId(opts, query)})" } - private[jdbc] def jdbcIoId(opts: JdbcIoOptions): String = opts match { - case JdbcReadOptions(connOpts, query, _, _, _, _) => jdbcIoId(connOpts, query) - case JdbcWriteOptions(connOpts, statement, _, _, _, _) => - jdbcIoId(connOpts, statement) - } - private[jdbc] def jdbcIoId(opts: JdbcConnectionOptions, query: String): String = { val user = opts.password .fold(s"${opts.username}")(password => s"${opts.username}:$password") @@ -49,52 +51,90 @@ object JdbcIO { private[jdbc] def dataSourceConfiguration( opts: JdbcConnectionOptions - ): beam.JdbcIO.DataSourceConfiguration = + ): BJdbcIO.DataSourceConfiguration = opts.password match { case Some(pass) => - beam.JdbcIO.DataSourceConfiguration + BJdbcIO.DataSourceConfiguration .create(opts.driverClass.getCanonicalName, opts.connectionUrl) .withUsername(opts.username) .withPassword(pass) case None => - beam.JdbcIO.DataSourceConfiguration + BJdbcIO.DataSourceConfiguration .create(opts.driverClass.getCanonicalName, opts.connectionUrl) .withUsername(opts.username) } + + object ReadParam { + private[jdbc] val BeamDefaultFetchSize = -1 + private[jdbc] val DefaultOutputParallelization = true + } + + final case class ReadParam[T]( + rowMapper: ResultSet => T, + statementPreparator: PreparedStatement => Unit = null, + fetchSize: Int = ReadParam.BeamDefaultFetchSize, + outputParallelization: Boolean = ReadParam.DefaultOutputParallelization, + dataSourceProviderFn: () => DataSource = null, + configOverride: BJdbcIO.Read[T] => BJdbcIO.Read[T] = identity[BJdbcIO.Read[T]] _ + ) + + object WriteParam { + private[jdbc] val BeamDefaultBatchSize = -1L + private[jdbc] val BeamDefaultMaxRetryAttempts = 5 + private[jdbc] val BeamDefaultInitialRetryDelay = org.joda.time.Duration.ZERO + 
private[jdbc] val BeamDefaultMaxRetryDelay = org.joda.time.Duration.ZERO + private[jdbc] val BeamDefaultRetryConfiguration = BJdbcIO.RetryConfiguration.create( + BeamDefaultMaxRetryAttempts, + BeamDefaultMaxRetryDelay, + BeamDefaultInitialRetryDelay + ) + private[jdbc] val DefaultRetryStrategy: SQLException => Boolean = + new BJdbcIO.DefaultRetryStrategy().apply + private[jdbc] val DefaultAutoSharding: Boolean = false + } + + final case class WriteParam[T]( + preparedStatementSetter: (T, PreparedStatement) => Unit, + batchSize: Long = WriteParam.BeamDefaultBatchSize, + retryConfiguration: BJdbcIO.RetryConfiguration = WriteParam.BeamDefaultRetryConfiguration, + retryStrategy: SQLException => Boolean = WriteParam.DefaultRetryStrategy, + autoSharding: Boolean = WriteParam.DefaultAutoSharding, + dataSourceProviderFn: () => DataSource = null, + configOverride: BJdbcIO.Write[T] => BJdbcIO.Write[T] = identity[BJdbcIO.Write[T]] _ + ) } -final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends JdbcIO[T] { - override type ReadP = Unit +final case class JdbcSelect[T: Coder](opts: JdbcConnectionOptions, query: String) + extends JdbcIO[T] { + override type ReadP = JdbcIO.ReadParam[T] override type WriteP = Nothing - final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] + override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] - override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(readOptions)})" + override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(opts, query)})" override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = { val coder = CoderMaterializer.beam(sc, Coder[T]) - var transform = beam.JdbcIO + var transform = BJdbcIO .read[T]() .withCoder(coder) - .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(readOptions.connectionOptions)) - .withQuery(readOptions.query) - .withRowMapper(new beam.JdbcIO.RowMapper[T] { - override def mapRow(resultSet: ResultSet): T = - readOptions.rowMapper(resultSet) - }) - .withOutputParallelization(readOptions.outputParallelization) - - if (readOptions.statementPreparator != null) { + .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(opts)) + .withQuery(query) + .withRowMapper(params.rowMapper(_)) + .withOutputParallelization(params.outputParallelization) + + if (params.dataSourceProviderFn != null) { + transform = transform.withDataSourceProviderFn((_: Void) => params.dataSourceProviderFn()) + } + if (params.statementPreparator != null) { transform = transform - .withStatementPreparator(new beam.JdbcIO.StatementPreparator { - override def setParameters(preparedStatement: PreparedStatement): Unit = - readOptions.statementPreparator(preparedStatement) - }) + .withStatementPreparator(params.statementPreparator(_)) } - if (readOptions.fetchSize != JdbcIoOptions.BeamDefaultFetchSize) { // override default fetch size.
- transform = transform.withFetchSize(readOptions.fetchSize) + transform = transform.withFetchSize(params.fetchSize) } - sc.applyTransform(transform) + + sc.applyTransform(params.configOverride(transform)) } override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = @@ -104,38 +144,42 @@ final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends J EmptyTap } -final case class JdbcWrite[T](writeOptions: JdbcWriteOptions[T]) extends JdbcIO[T] { +final case class JdbcWrite[T](opts: JdbcConnectionOptions, statement: String) extends JdbcIO[T] { override type ReadP = Nothing - override type WriteP = Unit - final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] + override type WriteP = JdbcIO.WriteParam[T] + override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] - override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(writeOptions)})" + override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(opts, statement)})" override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = throw new UnsupportedOperationException("jdbc.Write is write-only") override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = { - var transform = beam.JdbcIO + var transform = BJdbcIO .write[T]() - .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(writeOptions.connectionOptions)) - .withStatement(writeOptions.statement) - if (writeOptions.preparedStatementSetter != null) { + .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(opts)) + .withStatement(statement) + + if (params.dataSourceProviderFn != null) { + transform = transform.withDataSourceProviderFn((_: Void) => params.dataSourceProviderFn()) + } + if (params.preparedStatementSetter != null) { transform = transform - .withPreparedStatementSetter(new beam.JdbcIO.PreparedStatementSetter[T] { - override def setParameters(element: T, preparedStatement: PreparedStatement): Unit = - writeOptions.preparedStatementSetter(element, preparedStatement) - }) + .withPreparedStatementSetter(params.preparedStatementSetter(_, _)) } - if (writeOptions.batchSize != JdbcIoOptions.BeamDefaultBatchSize) { // override default batch size.
- transform = transform.withBatchSize(writeOptions.batchSize) + transform = transform.withBatchSize(params.batchSize) + } + if (params.autoSharding) { + transform = transform.withAutoSharding() } transform = transform - .withRetryConfiguration(writeOptions.retryConfiguration) - .withRetryStrategy(writeOptions.retryStrategy.apply) + .withRetryConfiguration(params.retryConfiguration) + .withRetryStrategy(params.retryStrategy.apply) - data.applyInternal(transform) + data.applyInternal(params.configOverride(transform)) EmptyTap } diff --git a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcOptions.scala b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIoOptions.scala similarity index 58% rename from scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcOptions.scala rename to scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIoOptions.scala index ec242d080b..2e88a1852d 100644 --- a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcOptions.scala +++ b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIoOptions.scala @@ -17,43 +17,11 @@ package com.spotify.scio.jdbc -import org.apache.beam.sdk.io.jdbc.JdbcIO.{DefaultRetryStrategy, RetryConfiguration} +import org.apache.beam.sdk.io.jdbc.JdbcIO.RetryConfiguration -import java.sql.{Driver, PreparedStatement, ResultSet, SQLException} - -/** - * Options for a JDBC connection. - * - * @param username - * database login username - * @param password - * database login password if exists - * @param connectionUrl - * connection url, i.e "jdbc:mysql://[host]:[port]/db?" - * @param driverClass - * subclass of [[java.sql.Driver]] - */ -final case class JdbcConnectionOptions( - username: String, - password: Option[String], - connectionUrl: String, - driverClass: Class[_ <: Driver] -) - -object JdbcIoOptions { - private[jdbc] val BeamDefaultBatchSize = -1L - private[jdbc] val BeamDefaultFetchSize = -1 - private[jdbc] val BeamDefaultMaxRetryAttempts = 5 - private[jdbc] val BeamDefaultInitialRetryDelay = org.joda.time.Duration.ZERO - private[jdbc] val BeamDefaultMaxRetryDelay = org.joda.time.Duration.ZERO - private[jdbc] val BeamDefaultRetryConfiguration = RetryConfiguration.create( - BeamDefaultMaxRetryAttempts, - BeamDefaultMaxRetryDelay, - BeamDefaultInitialRetryDelay - ) - private[jdbc] val DefaultOutputParallelization = true -} +import java.sql.{PreparedStatement, ResultSet, SQLException} +@deprecated("Use new API overloads with multiple parameters", since = "0.13.0") sealed trait JdbcIoOptions /** @@ -72,13 +40,14 @@ sealed trait JdbcIoOptions * @param outputParallelization * reshuffle result to distribute it to all workers. Default to true. 
*/ +@deprecated("Use new API overloads with multiple parameters", since = "0.13.0") final case class JdbcReadOptions[T]( connectionOptions: JdbcConnectionOptions, query: String, statementPreparator: PreparedStatement => Unit = null, rowMapper: ResultSet => T, - fetchSize: Int = JdbcIoOptions.BeamDefaultFetchSize, - outputParallelization: Boolean = JdbcIoOptions.DefaultOutputParallelization + fetchSize: Int = JdbcIO.ReadParam.BeamDefaultFetchSize, + outputParallelization: Boolean = JdbcIO.ReadParam.DefaultOutputParallelization ) extends JdbcIoOptions /** @@ -97,11 +66,12 @@ final case class JdbcReadOptions[T]( * @param retryStrategy * A predicate of [[java.sql.SQLException]] indicating a failure to retry */ +@deprecated("Use new API overloads with multiple parameters", since = "0.13.0") final case class JdbcWriteOptions[T]( connectionOptions: JdbcConnectionOptions, statement: String, preparedStatementSetter: (T, PreparedStatement) => Unit = null, - batchSize: Long = JdbcIoOptions.BeamDefaultBatchSize, - retryConfiguration: RetryConfiguration = JdbcIoOptions.BeamDefaultRetryConfiguration, - retryStrategy: SQLException => Boolean = new DefaultRetryStrategy().apply + batchSize: Long = JdbcIO.WriteParam.BeamDefaultBatchSize, + retryConfiguration: RetryConfiguration = JdbcIO.WriteParam.BeamDefaultRetryConfiguration, + retryStrategy: SQLException => Boolean = JdbcIO.WriteParam.DefaultRetryStrategy ) extends JdbcIoOptions diff --git a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/SCollectionSyntax.scala b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/SCollectionSyntax.scala index 2350654eac..319c0bd6dd 100644 --- a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/SCollectionSyntax.scala +++ b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/SCollectionSyntax.scala @@ -17,17 +17,76 @@ package com.spotify.scio.jdbc.syntax -import com.spotify.scio.values.SCollection -import com.spotify.scio.jdbc.JdbcWriteOptions import com.spotify.scio.io.ClosedTap -import com.spotify.scio.jdbc.JdbcWrite +import com.spotify.scio.jdbc.{JdbcConnectionOptions, JdbcIO, JdbcWrite, JdbcWriteOptions} +import com.spotify.scio.values.SCollection +import org.apache.beam.sdk.io.jdbc.JdbcIO.{RetryConfiguration, Write} + +import java.sql.{PreparedStatement, SQLException} +import javax.sql.DataSource /** Enhanced version of [[com.spotify.scio.values.SCollection SCollection]] with JDBC methods. */ final class JdbcSCollectionOps[T](private val self: SCollection[T]) extends AnyVal { /** Save this SCollection as a JDBC database. */ + @deprecated("Use another overload with multiple parameters", since = "0.13.0") def saveAsJdbc(writeOptions: JdbcWriteOptions[T]): ClosedTap[Nothing] = - self.write(JdbcWrite(writeOptions)) + saveAsJdbc( + writeOptions.connectionOptions, + writeOptions.statement, + writeOptions.batchSize, + writeOptions.retryConfiguration, + writeOptions.retryStrategy + )(writeOptions.preparedStatementSetter) + + /** + * Save this SCollection as a JDBC database. + * + * NB: in case of transient failures, Beam runners may execute parts of write multiple times for + * fault tolerance. Because of that, you should avoid using INSERT statements, since that risks + * duplicating records in the database, or failing due to primary key conflicts. Consider using + * MERGE ("upsert") statements supported by your database instead. 
* + * @param connectionOptions + * connection options + * @param statement + * the SQL statement to execute + * @param preparedStatementSetter + * function to set values in a [[java.sql.PreparedStatement]] + * @param batchSize + * use the Apache Beam default batch size if the value is -1 + * @param retryConfiguration + * [[org.apache.beam.sdk.io.jdbc.JdbcIO.RetryConfiguration]] for specifying retry behavior + * @param retryStrategy + * A predicate of [[java.sql.SQLException]] indicating a failure to retry + * @param autoSharding + * If true, enables using a dynamically determined number of shards to write. + * @param dataSourceProviderFn + * function to provide a custom [[javax.sql.DataSource]] + * @param configOverride + * function to override or replace a Write transform before applying it + */ + def saveAsJdbc( + connectionOptions: JdbcConnectionOptions, + statement: String, + batchSize: Long = JdbcIO.WriteParam.BeamDefaultBatchSize, + retryConfiguration: RetryConfiguration = JdbcIO.WriteParam.BeamDefaultRetryConfiguration, + retryStrategy: SQLException => Boolean = JdbcIO.WriteParam.DefaultRetryStrategy, + autoSharding: Boolean = JdbcIO.WriteParam.DefaultAutoSharding, + dataSourceProviderFn: () => DataSource = null, + configOverride: Write[T] => Write[T] = identity + )(preparedStatementSetter: (T, PreparedStatement) => Unit): ClosedTap[Nothing] = + self.write(JdbcWrite[T](connectionOptions, statement))( + JdbcIO.WriteParam( + preparedStatementSetter, + batchSize, + retryConfiguration, + retryStrategy, + autoSharding, + dataSourceProviderFn, + configOverride + ) + ) } trait SCollectionSyntax { diff --git a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/ScioContextSyntax.scala b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/ScioContextSyntax.scala index 389519a728..486189d10f 100644 --- a/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/ScioContextSyntax.scala +++ b/scio-jdbc/src/main/scala/com/spotify/scio/jdbc/syntax/ScioContextSyntax.scala @@ -17,25 +17,75 @@ package com.spotify.scio.jdbc.syntax -import com.spotify.scio.values.SCollection +import com.spotify.scio.ScioContext import com.spotify.scio.coders.Coder +import com.spotify.scio.jdbc.sharded.{JdbcShardedReadOptions, JdbcShardedSelect} +import com.spotify.scio.jdbc.{JdbcConnectionOptions, JdbcIO, JdbcReadOptions, JdbcSelect} +import com.spotify.scio.values.SCollection +import org.apache.beam.sdk.io.jdbc.JdbcIO.Read +import java.sql.{PreparedStatement, ResultSet} +import javax.sql.DataSource import scala.reflect.ClassTag -import com.spotify.scio.ScioContext -import com.spotify.scio.jdbc.sharded.{JdbcShardedReadOptions, JdbcShardedSelect} -import com.spotify.scio.jdbc.{JdbcReadOptions, JdbcSelect} /** Enhanced version of [[ScioContext]] with JDBC methods. */ final class JdbcScioContextOps(private val self: ScioContext) extends AnyVal { /** Get an SCollection for a JDBC query. */ + @deprecated("Use another overload with multiple parameters", since = "0.13.0") def jdbcSelect[T: ClassTag: Coder](readOptions: JdbcReadOptions[T]): SCollection[T] = - self.read(JdbcSelect(readOptions)) + jdbcSelect( + readOptions.connectionOptions, + readOptions.query, + readOptions.statementPreparator, + readOptions.fetchSize, + readOptions.outputParallelization + )(readOptions.rowMapper) + + /** + * Get an SCollection for a JDBC query.
* + * @param connectionOptions + * connection options + * @param query + * query string + * @param rowMapper + * function to map from a SQL [[java.sql.ResultSet]] to `T` + * @param statementPreparator + * function to prepare a [[java.sql.PreparedStatement]] + * @param fetchSize + * use the Apache Beam default fetch size if the value is -1 + * @param outputParallelization + * reshuffle the result to distribute it to all workers. Defaults to true. + * @param dataSourceProviderFn + * function to provide a custom [[javax.sql.DataSource]] + * @param configOverride + * function to override or replace a Read transform before applying it + */ + def jdbcSelect[T: ClassTag: Coder]( + connectionOptions: JdbcConnectionOptions, + query: String, + statementPreparator: PreparedStatement => Unit = null, + fetchSize: Int = JdbcIO.ReadParam.BeamDefaultFetchSize, + outputParallelization: Boolean = JdbcIO.ReadParam.DefaultOutputParallelization, + dataSourceProviderFn: () => DataSource = null, + configOverride: Read[T] => Read[T] = identity[Read[T]] _ + )(rowMapper: ResultSet => T): SCollection[T] = + self.read(JdbcSelect(connectionOptions, query))( + JdbcIO.ReadParam( + rowMapper, + statementPreparator, + fetchSize, + outputParallelization, + dataSourceProviderFn, + configOverride + ) + ) /** * Sharded JDBC read from a table or materialized view. * @param readOptions - * The following paramters in the options class could be specified: + * The following parameters in the options class can be specified: * * shardColumn: the column to shard by. Must be of integer/long type ideally with evenly * distributed values. diff --git a/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcIOTests.scala b/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcIOTests.scala new file mode 100644 index 0000000000..e947f03e62 --- /dev/null +++ b/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcIOTests.scala @@ -0,0 +1,120 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
*/ + +package com.spotify.scio.jdbc + +import com.spotify.scio.ScioContext +import org.apache.beam.sdk.Pipeline.PipelineVisitor +import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior +import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO} +import org.apache.beam.sdk.runners.TransformHierarchy +import org.apache.beam.sdk.transforms.PTransform +import org.apache.beam.sdk.transforms.display.DisplayData +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should._ + +import java.sql.ResultSet +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +case class Foo(field: String) + +class JdbcIOTests extends AnyFlatSpec with Matchers { + + private val ReadQueryId = DisplayData.Identifier.of( + DisplayData.Path.root(), + classOf[BJdbcIO.Read[_]], + "query" + ) + + private val WriteStatementId = "[fn]class org.apache.beam.sdk.io.jdbc.JdbcIO$WriteFn:statement" + + behavior of "JdbcIO" + + it must "apply an overridden Read transform to the pipeline" in { + val args = Array[String]() + val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args) + val sc = ScioContext(opts) + sc.jdbcSelect[String]( + getConnectionOptions(opts), + "initial query", + configOverride = (x: BJdbcIO.Read[String]) => x.withQuery("overridden query") + ) { (rs: ResultSet) => + rs.getString(1) + } + + val transform = getPipelineTransforms(sc).collect { case t: BJdbcIO.Read[String] => t }.head + val displayData = DisplayData.from(transform).asMap().asScala + + displayData should contain key ReadQueryId + displayData(ReadQueryId).getValue should be("overridden query") + } + + it must "apply an overridden Write transform to the pipeline" in { + val args = Array[String]() + val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args) + val sc = ScioContext(opts) + + sc.parallelize(List("1", "2", "3")) + .saveAsJdbc( + getConnectionOptions(opts), + "INSERT INTO VALUES( ?, ? ..?)", + configOverride = _.withStatement("updated statement") + ) { (_, _) => } + + val transform = getPipelineTransforms(sc).filter(x => x.toString.contains("WriteFn")).head + val displayData = + DisplayData.from(transform).asMap().asScala.map { case (k, v) => (k.toString, v) } + + displayData should contain key WriteStatementId + displayData(WriteStatementId).getValue should be("updated statement") + } + + private def getPipelineTransforms(sc: ScioContext): Iterable[PTransform[_, _]] = { + val actualTransforms = new ArrayBuffer[PTransform[_, _]]() + sc.pipeline.traverseTopologically(new PipelineVisitor.Defaults { + override def enterCompositeTransform( + node: TransformHierarchy#Node + ): PipelineVisitor.CompositeBehavior = { + if (node.getTransform != null) { + actualTransforms.append(node.getTransform) + } + CompositeBehavior.ENTER_TRANSFORM + } + + override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit = + if (node.getTransform != null) { + actualTransforms.append(node.getTransform) + } + }) + actualTransforms + } + + def getConnectionOptions(opts: CloudSqlOptions): JdbcConnectionOptions = + JdbcConnectionOptions( + username = opts.getCloudSqlUsername, + password = Some(opts.getCloudSqlPassword), + connectionUrl = connectionUrl(opts), + driverClass = classOf[java.sql.Driver] + ) + + def connectionUrl(opts: CloudSqlOptions): String = + s"jdbc:mysql://google/${opts.getCloudSqlDb}?"
+ + s"cloudSqlInstance=${opts.getCloudSqlInstanceConnectionName}&" + + s"socketFactory=com.google.cloud.sql.mysql.SocketFactory" + +} diff --git a/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcTest.scala b/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcTest.scala index cf8a1950b4..116ab58a31 100644 --- a/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcTest.scala +++ b/scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcTest.scala @@ -24,29 +24,20 @@ import org.apache.beam.sdk.io.{jdbc => beam} import com.spotify.scio.testing._ object JdbcJob { + + val query = "SELECT FROM " + val statement = "INSERT INTO VALUES( ?, ? ..?)" def main(cmdlineArgs: Array[String]): Unit = { val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](cmdlineArgs) val sc = ScioContext(opts) - sc.jdbcSelect(getReadOptions(opts)) + val connectionOpts = getConnectionOptions(opts) + sc.jdbcSelect[String](connectionOpts, query)((rs: ResultSet) => rs.getString(1)) .map(_ + "J") - .saveAsJdbc(getWriteOptions(opts)) + .saveAsJdbc(connectionOpts, statement) { (_, _) => } sc.run() () } - def getReadOptions(opts: CloudSqlOptions): JdbcReadOptions[String] = - JdbcReadOptions( - connectionOptions = getConnectionOptions(opts), - query = "SELECT FROM ", - rowMapper = (rs: ResultSet) => rs.getString(1) - ) - - def getWriteOptions(opts: CloudSqlOptions): JdbcWriteOptions[String] = - JdbcWriteOptions[String]( - connectionOptions = getConnectionOptions(opts), - statement = "INSERT INTO VALUES( ?, ? ..?)" - ) - def connectionUrl(opts: CloudSqlOptions): String = s"jdbc:mysql://google/${opts.getCloudSqlDb}?" + s"cloudSqlInstance=${opts.getCloudSqlInstanceConnectionName}&" + @@ -70,13 +61,14 @@ class JdbcTest extends PipelineSpec { "--cloudSqlInstanceConnectionName=project-id:zone:db-instance-name" ) val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args.toArray) - val readOpts = JdbcJob.getReadOptions(opts) - val writeOpts = JdbcJob.getWriteOptions(opts) + val connectionOpts = JdbcJob.getConnectionOptions(opts) JobTest[JdbcJob.type] .args(args: _*) - .input(JdbcIO[String](readOpts), Seq("a", "b", "c")) - .output(JdbcIO[String](writeOpts))(coll => coll should containInAnyOrder(xs)) + .input(JdbcIO[String](connectionOpts, JdbcJob.query), Seq("a", "b", "c")) + .output(JdbcIO[String](connectionOpts, JdbcJob.statement))(coll => + coll should containInAnyOrder(xs) + ) .run() } @@ -96,15 +88,16 @@ class JdbcTest extends PipelineSpec { "--cloudSqlInstanceConnectionName=project-id:zone:db-instance-name" ) val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args.toArray) - val readOpts = JdbcJob.getReadOptions(opts) - val writeOpts = JdbcJob.getWriteOptions(opts) + val connectionOpts = JdbcJob.getConnectionOptions(opts) val expected = Seq("aJ", "bJ", "cJ") JobTest[JdbcJob.type] .args(args: _*) - .input(JdbcIO[String](readOpts), Seq("a", "b", "c")) - .output(JdbcIO[String](writeOpts))(coll => coll should containInAnyOrder(expected)) + .input(JdbcIO[String](connectionOpts, JdbcJob.query), Seq("a", "b", "c")) + .output(JdbcIO[String](connectionOpts, JdbcJob.statement))(coll => + coll should containInAnyOrder(expected) + ) .run() }