-
Notifications
You must be signed in to change notification settings - Fork 513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added new JdbcIO read/write params to Scio #4820
Changes from 3 commits
9aa40ed
5e2ee53
0b0bac8
58ed32a
13c928c
83d9e1c
5315e90
05416d6
cc5cc82
e513de6
2382d06
b0cec13
d56b60b
01e74cc
0a7c741
8c46df0
8cb973b
5ed97d2
281024c
e583a85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,14 +17,11 @@ | |
|
||
package com.spotify.scio.jdbc | ||
|
||
import com.spotify.scio.values.SCollection | ||
import com.spotify.scio.ScioContext | ||
import com.spotify.scio.io.{EmptyTap, EmptyTapOf, ScioIO, Tap, TestIO} | ||
import org.apache.beam.sdk.io.{jdbc => beam} | ||
import java.sql.{PreparedStatement, ResultSet} | ||
|
||
import com.spotify.scio.coders.{Coder, CoderMaterializer} | ||
import com.spotify.scio.io.TapT | ||
import com.spotify.scio.io._ | ||
import com.spotify.scio.values.SCollection | ||
import org.apache.beam.sdk.io.{jdbc => beam} | ||
|
||
sealed trait JdbcIO[T] extends ScioIO[T] | ||
|
||
|
@@ -36,9 +33,10 @@ object JdbcIO { | |
} | ||
|
||
private[jdbc] def jdbcIoId(opts: JdbcIoOptions): String = opts match { | ||
case JdbcReadOptions(connOpts, query, _, _, _, _) => jdbcIoId(connOpts, query) | ||
case JdbcWriteOptions(connOpts, statement, _, _, _, _) => | ||
jdbcIoId(connOpts, statement) | ||
case readOpts: JdbcReadOptions[_] => | ||
jdbcIoId(readOpts.connectionOptions, readOpts.query) | ||
case writeOpts: JdbcWriteOptions[_] => | ||
jdbcIoId(writeOpts.connectionOptions, writeOpts.statement) | ||
} | ||
|
||
private[jdbc] def jdbcIoId(opts: JdbcConnectionOptions, query: String): String = { | ||
|
@@ -66,7 +64,7 @@ object JdbcIO { | |
final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends JdbcIO[T] { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IMHO we've not defined the constructor for […] All the other params should be passed as […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I also noted that it stands out from other IOs. Will change it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's do it in another PR. I think we have to check some other IOs too (I recall CsvIO has the same issue). |
||
override type ReadP = Unit | ||
override type WriteP = Nothing | ||
final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] | ||
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] | ||
|
||
override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(readOptions)})" | ||
|
||
|
@@ -77,24 +75,21 @@ final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends J | |
.withCoder(coder) | ||
.withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(readOptions.connectionOptions)) | ||
.withQuery(readOptions.query) | ||
.withRowMapper(new beam.JdbcIO.RowMapper[T] { | ||
override def mapRow(resultSet: ResultSet): T = | ||
readOptions.rowMapper(resultSet) | ||
}) | ||
.withRowMapper(readOptions.rowMapper(_)) | ||
.withOutputParallelization(readOptions.outputParallelization) | ||
|
||
if (readOptions.dataSourceProviderFn != null) { | ||
transform.withDataSourceProviderFn((_: Void) => readOptions.dataSourceProviderFn()) | ||
} | ||
if (readOptions.statementPreparator != null) { | ||
transform = transform | ||
.withStatementPreparator(new beam.JdbcIO.StatementPreparator { | ||
override def setParameters(preparedStatement: PreparedStatement): Unit = | ||
readOptions.statementPreparator(preparedStatement) | ||
}) | ||
.withStatementPreparator(readOptions.statementPreparator(_)) | ||
} | ||
if (readOptions.fetchSize != JdbcIoOptions.BeamDefaultFetchSize) { | ||
// override default fetch size. | ||
transform = transform.withFetchSize(readOptions.fetchSize) | ||
} | ||
sc.applyTransform(transform) | ||
sc.applyTransform(readOptions.configOverride(transform)) | ||
} | ||
|
||
override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = | ||
|
@@ -107,7 +102,7 @@ final case class JdbcSelect[T: Coder](readOptions: JdbcReadOptions[T]) extends J | |
final case class JdbcWrite[T](writeOptions: JdbcWriteOptions[T]) extends JdbcIO[T] { | ||
override type ReadP = Nothing | ||
override type WriteP = Unit | ||
final override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] | ||
override val tapT: TapT.Aux[T, Nothing] = EmptyTapOf[T] | ||
|
||
override def testId: String = s"JdbcIO(${JdbcIO.jdbcIoId(writeOptions)})" | ||
|
||
|
@@ -119,23 +114,27 @@ final case class JdbcWrite[T](writeOptions: JdbcWriteOptions[T]) extends JdbcIO[ | |
.write[T]() | ||
.withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(writeOptions.connectionOptions)) | ||
.withStatement(writeOptions.statement) | ||
|
||
if (writeOptions.dataSourceProviderFn != null) { | ||
transform.withDataSourceProviderFn((_: Void) => writeOptions.dataSourceProviderFn()) | ||
} | ||
if (writeOptions.preparedStatementSetter != null) { | ||
transform = transform | ||
.withPreparedStatementSetter(new beam.JdbcIO.PreparedStatementSetter[T] { | ||
override def setParameters(element: T, preparedStatement: PreparedStatement): Unit = | ||
writeOptions.preparedStatementSetter(element, preparedStatement) | ||
}) | ||
.withPreparedStatementSetter(writeOptions.preparedStatementSetter(_, _)) | ||
} | ||
if (writeOptions.batchSize != JdbcIoOptions.BeamDefaultBatchSize) { | ||
// override default batch size. | ||
transform = transform.withBatchSize(writeOptions.batchSize) | ||
} | ||
if (writeOptions.autoSharding) { | ||
transform = transform.withAutoSharding() | ||
} | ||
|
||
transform = transform | ||
.withRetryConfiguration(writeOptions.retryConfiguration) | ||
.withRetryStrategy(writeOptions.retryStrategy.apply) | ||
|
||
data.applyInternal(transform) | ||
data.applyInternal(writeOptions.configOverride(transform)) | ||
EmptyTap | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package com.spotify.scio.jdbc | ||
|
||
import com.spotify.scio.ScioContext | ||
import org.apache.beam.sdk.Pipeline.PipelineVisitor | ||
import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior | ||
import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO} | ||
import org.apache.beam.sdk.runners.TransformHierarchy | ||
import org.apache.beam.sdk.transforms.PTransform | ||
import org.scalatest.flatspec.AnyFlatSpec | ||
import org.scalatest.matchers.should._ | ||
|
||
import java.sql.ResultSet | ||
import scala.collection.mutable.ArrayBuffer | ||
|
||
class JdbcIOTests3 extends AnyFlatSpec with Matchers { | ||
|
||
it must "add to pipeline overridden Read transform" in { | ||
val args = Array[String]() | ||
val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args) | ||
val sc = ScioContext(opts) | ||
var expectedTransform: BJdbcIO.Read[String] = null | ||
sc.jdbcSelect[String]( | ||
getDefaultReadOptions(opts).copy(configOverride = r => { | ||
expectedTransform = r.withQuery("overridden query") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we get back the query instead of memorizing the transform in a var? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, we can. This was the simplest code; otherwise we would need to match the transform by type, which is not difficult either :) |
||
expectedTransform | ||
}) | ||
) | ||
|
||
expectedTransform should not be null | ||
getPipelineTransforms(sc) should contain(expectedTransform) | ||
} | ||
|
||
private def getPipelineTransforms(sc: ScioContext): Iterable[PTransform[_, _]] = { | ||
val actualTransforms = new ArrayBuffer[PTransform[_, _]]() | ||
sc.pipeline.traverseTopologically(new PipelineVisitor.Defaults { | ||
override def enterCompositeTransform( | ||
node: TransformHierarchy#Node | ||
): PipelineVisitor.CompositeBehavior = { | ||
actualTransforms.addOne(node.getTransform) | ||
CompositeBehavior.ENTER_TRANSFORM | ||
} | ||
}) | ||
actualTransforms | ||
} | ||
|
||
def getDefaultReadOptions(opts: CloudSqlOptions): JdbcReadOptions[String] = | ||
JdbcReadOptions( | ||
connectionOptions = getConnectionOptions(opts), | ||
query = "SELECT <this> FROM <this>", | ||
rowMapper = (rs: ResultSet) => rs.getString(1) | ||
) | ||
|
||
def getWriteOptions(opts: CloudSqlOptions): JdbcWriteOptions[String] = | ||
JdbcWriteOptions[String]( | ||
connectionOptions = getConnectionOptions(opts), | ||
statement = "INSERT INTO <this> VALUES( ?, ? ..?)" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks like this is unused so far |
||
|
||
def getConnectionOptions(opts: CloudSqlOptions): JdbcConnectionOptions = | ||
JdbcConnectionOptions( | ||
username = opts.getCloudSqlUsername, | ||
password = Some(opts.getCloudSqlPassword), | ||
connectionUrl = connectionUrl(opts), | ||
classOf[java.sql.Driver] | ||
) | ||
|
||
def connectionUrl(opts: CloudSqlOptions): String = | ||
s"jdbc:mysql://google/${opts.getCloudSqlDb}?" + | ||
s"cloudSqlInstance=${opts.getCloudSqlInstanceConnectionName}&" + | ||
s"socketFactory=com.google.cloud.sql.mysql.SocketFactory" | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nicer like this!