Use transform finder in JdbcIOTest (#4865)
RustedBones authored Jun 8, 2023
1 parent ea7e98b commit 1f32fd6
Showing 2 changed files with 73 additions and 75 deletions.
73 changes: 42 additions & 31 deletions scio-jdbc/src/main/scala/com/spotify/scio/jdbc/JdbcIO.scala
@@ -20,12 +20,15 @@ package com.spotify.scio.jdbc
 import com.spotify.scio.ScioContext
 import com.spotify.scio.coders.{Coder, CoderMaterializer}
 import com.spotify.scio.io._
+import com.spotify.scio.util.Functions
 import com.spotify.scio.values.SCollection
+import org.apache.beam.sdk.io.jdbc.JdbcIO.{PreparedStatementSetter, StatementPreparator}
 import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO}
 import org.joda.time.Duration
 
 import java.sql.{PreparedStatement, ResultSet, SQLException}
 import javax.sql.DataSource
+import scala.util.chaining._
 
 sealed trait JdbcIO[T] extends ScioIO[T]
 
@@ -121,25 +124,31 @@ final case class JdbcSelect[T: Coder](opts: JdbcConnectionOptions, query: String
 
   override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = {
     val coder = CoderMaterializer.beam(sc, Coder[T])
-    var transform = BJdbcIO
+    val transform = BJdbcIO
       .read[T]()
       .withCoder(coder)
       .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(opts))
       .withQuery(query)
       .withRowMapper(params.rowMapper(_))
       .withOutputParallelization(params.outputParallelization)
-
-    if (params.dataSourceProviderFn != null) {
-      transform.withDataSourceProviderFn((_: Void) => params.dataSourceProviderFn())
-    }
-    if (params.statementPreparator != null) {
-      transform = transform
-        .withStatementPreparator(params.statementPreparator(_))
-    }
-    if (params.fetchSize != JdbcIO.ReadParam.BeamDefaultFetchSize) {
-      // override default fetch size.
-      transform = transform.withFetchSize(params.fetchSize)
-    }
+      .pipe { r =>
+        Option(params.dataSourceProviderFn)
+          .map(fn => Functions.serializableFn[Void, DataSource](_ => fn()))
+          .fold(r)(r.withDataSourceProviderFn)
+      }
+      .pipe { r =>
+        Option(params.statementPreparator)
+          .map[StatementPreparator](fn => fn(_))
+          .fold(r)(r.withStatementPreparator)
+      }
+      .pipe { r =>
+        if (params.fetchSize != JdbcIO.ReadParam.BeamDefaultFetchSize) {
+          // override default fetch size.
+          r.withFetchSize(params.fetchSize)
+        } else {
+          r
+        }
+      }
 
     sc.applyTransform(params.configOverride(transform))
   }
@@ -162,29 +171,31 @@ final case class JdbcWrite[T](opts: JdbcConnectionOptions, statement: String) ex
     throw new UnsupportedOperationException("jdbc.Write is write-only")
 
   override protected def write(data: SCollection[T], params: WriteP): Tap[Nothing] = {
-    var transform = BJdbcIO
+    val transform = BJdbcIO
       .write[T]()
       .withDataSourceConfiguration(JdbcIO.dataSourceConfiguration(opts))
       .withStatement(statement)
-
-    if (params.dataSourceProviderFn != null) {
-      transform.withDataSourceProviderFn((_: Void) => params.dataSourceProviderFn())
-    }
-    if (params.preparedStatementSetter != null) {
-      transform = transform
-        .withPreparedStatementSetter(params.preparedStatementSetter(_, _))
-    }
-    if (params.batchSize != JdbcIO.WriteParam.BeamDefaultBatchSize) {
-      // override default batch size.
-      transform = transform.withBatchSize(params.batchSize)
-    }
-    if (params.autoSharding) {
-      transform = transform.withAutoSharding()
-    }
-
-    transform = transform
       .withRetryConfiguration(params.retryConfiguration)
       .withRetryStrategy(params.retryStrategy.apply)
+      .pipe { w =>
+        Option(params.dataSourceProviderFn)
+          .map(fn => Functions.serializableFn[Void, DataSource](_ => fn()))
+          .fold(w)(w.withDataSourceProviderFn)
+      }
+      .pipe { w =>
+        Option(params.preparedStatementSetter)
+          .map[PreparedStatementSetter[T]](fn => fn(_, _))
+          .fold(w)(w.withPreparedStatementSetter)
+      }
+      .pipe { w =>
+        if (params.batchSize != JdbcIO.WriteParam.BeamDefaultBatchSize) {
+          // override default batch size.
+          w.withBatchSize(params.batchSize)
+        } else {
+          w
+        }
+      }
+      .pipe(w => if (params.autoSharding) w.withAutoSharding() else w)
 
     data.applyInternal(params.configOverride(transform))
     EmptyTap
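The refactor above swaps `var` reassignment for `scala.util.chaining.pipe`, threading the Beam builder through a chain of pure steps. Besides reading better, it removes a real hazard: in the old read and write paths the result of `withDataSourceProviderFn` was computed and silently discarded (there was no reassignment), so the data source provider was never actually applied. A minimal, self-contained sketch of the idiom, with a hypothetical `Settings` builder standing in for Beam's `Read`/`Write`:

import scala.util.chaining._

object PipeIdiom {
  // Hypothetical immutable builder standing in for BJdbcIO.Read / BJdbcIO.Write.
  final case class Settings(query: String = "", fetchSize: Int = 50000)

  def configure(query: String, fetchSize: java.lang.Integer): Settings =
    Settings()
      .pipe(s => s.copy(query = query))
      .pipe { s =>
        // Option(...) absorbs null, replacing the old `if (x != null)` guards;
        // fold keeps the builder unchanged when no override was given.
        Option(fetchSize).fold(s)(n => s.copy(fetchSize = n))
      }

  def main(args: Array[String]): Unit = {
    println(configure("SELECT 1", null)) // Settings(SELECT 1,50000): default kept
    println(configure("SELECT 1", 100))  // Settings(SELECT 1,100): override applied
  }
}

The `Option(params.statementPreparator).map[StatementPreparator](fn => fn(_)).fold(...)` shape in the diff works the same way; the explicit type argument on `map` pins the expected type so Scala's eta-expansion produces the Java SAM interface that Beam's API expects.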
75 changes: 31 additions & 44 deletions scio-jdbc/src/test/scala/com/spotify/scio/jdbc/JdbcIOTest.scala
@@ -18,44 +18,48 @@
 package com.spotify.scio.jdbc
 
 import com.spotify.scio.ScioContext
-import org.apache.beam.sdk.Pipeline.PipelineVisitor
-import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior
+import com.spotify.scio.testing.{EqualNamePTransformMatcher, TransformFinder}
 import org.apache.beam.sdk.io.jdbc.{JdbcIO => BJdbcIO}
-import org.apache.beam.sdk.runners.TransformHierarchy
-import org.apache.beam.sdk.transforms.PTransform
 import org.apache.beam.sdk.transforms.display.DisplayData
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should._
 
-import java.sql.ResultSet
-import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
 case class Foo(field: String)
 
-class JdbcIOTests extends AnyFlatSpec with Matchers {
-
+object JdbcIOTest {
   private val ReadQueryId = DisplayData.Identifier.of(
     DisplayData.Path.root(),
     classOf[BJdbcIO.Read[_]],
     "query"
   )
 
-  private val WriteStatementId = "[fn]class org.apache.beam.sdk.io.jdbc.JdbcIO$WriteFn:statement"
+  private val WriteStatementId = DisplayData.Identifier.of(
+    DisplayData.Path.absolute("fn"),
+    Class.forName("org.apache.beam.sdk.io.jdbc.JdbcIO$WriteFn"),
+    "statement"
+  )
+}
+
+class JdbcIOTest extends AnyFlatSpec with Matchers {
+  import JdbcIOTest._
 
   it must "add to pipeline overridden Read transform" in {
     val args = Array[String]()
     val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args)
+    val name = "jdbcSelect"
     val sc = ScioContext(opts)
-    sc.jdbcSelect[String](
-      getConnectionOptions(opts),
-      "initial query",
-      configOverride = (x: BJdbcIO.Read[String]) => x.withQuery("overridden query")
-    ) { (rs: ResultSet) =>
-      rs.getString(1)
-    }
+    sc.withName(name)
+      .jdbcSelect[String](
+        getConnectionOptions(opts),
+        "initial query",
+        configOverride = (x: BJdbcIO.Read[String]) => x.withQuery("overridden query")
+      )(rs => rs.getString(1))
 
-    val transform = getPipelineTransforms(sc).collect { case t: BJdbcIO.Read[String] => t }.head
+    val finder = new TransformFinder(new EqualNamePTransformMatcher(name))
+    sc.pipeline.traverseTopologically(finder)
+    val transform = finder.result().head
     val displayData = DisplayData.from(transform).asMap().asScala
 
     displayData should contain key ReadQueryId
@@ -66,42 +70,25 @@ class JdbcIOTests extends AnyFlatSpec with Matchers {
     val args = Array[String]()
     val (opts, _) = ScioContext.parseArguments[CloudSqlOptions](args)
     val sc = ScioContext(opts)
-
+    val name = "saveAsJdbc"
     sc.parallelize(List("1", "2", "3"))
+      .withName(name)
       .saveAsJdbc(
         getConnectionOptions(opts),
-        "INSERT INTO <this> VALUES( ?, ? ..?)",
+        "INSERT INTO <this> VALUES(?)",
         configOverride = _.withStatement("updated statement")
-      ) { (_, _) => }
-
-    val transform = getPipelineTransforms(sc).filter(x => x.toString.contains("WriteFn")).head
-    val displayData =
-      DisplayData.from(transform).asMap().asScala.map { case (k, v) => (k.toString, v) }
+      )((x, ps) => ps.setString(0, x))
 
+    // find the underlying jdbc write
+    val finder = new TransformFinder(new EqualNamePTransformMatcher(name + "/ParDo(Write)"))
+    sc.pipeline.traverseTopologically(finder)
+    val transform = finder.result().head
+    val displayData = DisplayData.from(transform).asMap().asScala
+    println(displayData)
     displayData should contain key WriteStatementId
     displayData(WriteStatementId).getValue should be("updated statement")
   }
 
-  private def getPipelineTransforms(sc: ScioContext): Iterable[PTransform[_, _]] = {
-    val actualTransforms = new ArrayBuffer[PTransform[_, _]]()
-    sc.pipeline.traverseTopologically(new PipelineVisitor.Defaults {
-      override def enterCompositeTransform(
-        node: TransformHierarchy#Node
-      ): PipelineVisitor.CompositeBehavior = {
-        if (node.getTransform != null) {
-          actualTransforms.append(node.getTransform)
-        }
-        CompositeBehavior.ENTER_TRANSFORM
-      }
-
-      override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
-        if (node.getTransform != null) {
-          actualTransforms.append(node.getTransform)
-        }
-    })
-    actualTransforms
-  }
-
   def getConnectionOptions(opts: CloudSqlOptions): JdbcConnectionOptions =
     JdbcConnectionOptions(
       username = opts.getCloudSqlUsername,
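`TransformFinder` and `EqualNamePTransformMatcher` come from scio's `com.spotify.scio.testing` package and replace the hand-rolled `getPipelineTransforms` visitor deleted above: instead of collecting every transform and filtering on `toString`, the test addresses one transform by its full name. Beam transform names are hierarchical, which is why the write test matches `name + "/ParDo(Write)"` to reach the ParDo nested inside the composite write. A rough standalone equivalent built only on Beam's public visitor API; `NamedTransformFinder` and its exact-match rule are illustrative assumptions, not scio's actual implementation:

import org.apache.beam.sdk.Pipeline.PipelineVisitor
import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior
import org.apache.beam.sdk.runners.TransformHierarchy
import org.apache.beam.sdk.transforms.PTransform

import scala.collection.mutable.ArrayBuffer

// Collects the transforms whose full hierarchical name matches exactly.
class NamedTransformFinder(fullName: String) extends PipelineVisitor.Defaults {
  private val matched = ArrayBuffer.empty[PTransform[_, _]]

  override def enterCompositeTransform(
    node: TransformHierarchy#Node
  ): CompositeBehavior = {
    if (node.getFullName == fullName && node.getTransform != null) {
      matched += node.getTransform
      // Children extend the parent's full name, so they cannot also match exactly.
      CompositeBehavior.DO_NOT_ENTER_TRANSFORM
    } else {
      CompositeBehavior.ENTER_TRANSFORM
    }
  }

  override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
    if (node.getFullName == fullName && node.getTransform != null) {
      matched += node.getTransform
    }

  def result(): Seq[PTransform[_, _]] = matched.toSeq
}

Usage mirrors the test: `sc.pipeline.traverseTopologically(finder)` followed by `finder.result().head`, with `DisplayData.from(transform)` then exposing properties such as the overridden query or statement.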
