Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set provided name on transforms with side outputs #4779

Merged
merged 4 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,9 @@ def previousVersion(currentVersion: String): Option[String] = {

lazy val mimaSettings = Def.settings(
mimaBinaryIssueFilters := Seq(
// minor scio-tensorflow breaking changes for 0.12.6
// minor scio-core breaking changes for 0.12.8
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.tensorflow.syntax.SeqExampleSCollectionOps.saveAsTfRecordFile"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.tensorflow.syntax.SeqExampleSCollectionOps.saveAsTfRecordFile$extension"
),
// minor scio-grpc breaking changes for 0.12.6
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.grpc.GrpcSCollectionOps.grpcLookup"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.grpc.GrpcSCollectionOps.grpcLookup$extension"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.grpc.GrpcSCollectionOps.grpcLookupStream"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.grpc.GrpcSCollectionOps.grpcLookupStream$extension"
"com.spotify.scio.values.SCollectionWithSideOutput.this"
)
),
mimaPreviousArtifacts :=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,7 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
* @group side
*/
def withSideOutputs(sides: SideOutput[_]*): SCollectionWithSideOutput[T] =
new SCollectionWithSideOutput[T](internal, context, sides)
new SCollectionWithSideOutput[T](this, sides)

// =======================================================================
// Windowing operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,27 @@ import scala.jdk.CollectionConverters._
* s of the [[SideOutput]] s are accessed via the additional [[SideOutputCollections]] return value.
*/
class SCollectionWithSideOutput[T] private[values] (
val internal: PCollection[T],
val context: ScioContext,
Comment on lines -35 to -36
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change.
We need to have access to the SCollection to get the name. The new definition is consistent with the other SCollection wrappers.

coll: SCollection[T],
sides: Iterable[SideOutput[_]]
) extends PCollectionWrapper[T] {

override val internal: PCollection[T] = coll.internal
override val context: ScioContext = coll.context

private val sideTags = TupleTagList.of(sides.map(_.tupleTag).toList.asJava)

/**
 * Names the next transform by forwarding `name` to the wrapped `coll`.
 *
 * The name is set on `coll` rather than on this wrapper; `apply` reads it back through
 * `coll.tfName` when it applies the `ParDo`.
 *
 * @param name
 *   name forwarded to `coll.withName`
 * @return
 *   this wrapper, so calls can be chained
 */
override def withName(name: String): this.type = {
coll.withName(name)
this
}

private def apply[U: Coder](f: DoFn[T, U]): (SCollection[U], SideOutputCollections) = {
val mainTag = new TupleTag[U]

val dofn = ParDo.of(f).withOutputTags(mainTag, sideTags)
val tuple = this.applyInternal(dofn)
val tuple = this.applyInternal(coll.tfName, dofn)

val main =
tuple.get(mainTag).setCoder(CoderMaterializer.beam(context, Coder[U]))
val main = tuple.get(mainTag).setCoder(CoderMaterializer.beam(context, Coder[U]))

sides.foreach { s =>
tuple
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ trait TransformNameable {
nameProvider.getClass != classOf[ConstNameProvider],
s"withName() has already been used to set '$tfName' as the name for the next transform."
)
nameProvider = new ConstNameProvider(name)
nameProvider = ConstNameProvider(name)
this
}
}
Expand Down
230 changes: 108 additions & 122 deletions scio-test/src/test/scala/com/spotify/scio/values/NamedTransformTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package com.spotify.scio.values

import com.spotify.scio.ScioContext
import com.spotify.scio.testing.PipelineSpec
import com.spotify.scio.util.MultiJoin
import org.apache.beam.sdk.Pipeline
Expand Down Expand Up @@ -79,152 +80,139 @@ trait NamedTransformSpec extends PipelineSpec {

class NamedTransformTest extends NamedTransformSpec {
"ScioContext" should "support custom transform name" in {
runWithContext { sc =>
val p = sc.withName("ReadInput").parallelize(Seq("a", "b", "c"))
assertTransformNameStartsWith(p, "ReadInput/Read")
}
val sc = ScioContext()
val p = sc.withName("ReadInput").parallelize(Seq("a", "b", "c"))
assertTransformNameStartsWith(p, "ReadInput/Read")
}

"SCollection" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(1, 2, 3, 4, 5))
.map(_ * 3)
.withName("OnlyEven")
.filter(_ % 2 == 0)
assertTransformNameStartsWith(p, "OnlyEven")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(1, 2, 3, 4, 5))
.map(_ * 3)
.withName("OnlyEven")
.filter(_ % 2 == 0)
assertTransformNameStartsWith(p, "OnlyEven")
}

"DoubleSCollectionFunctions" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(1.0, 2.0, 3.0, 4.0, 5.0))
.withName("CalcVariance")
.variance
assertTransformNameStartsWith(p, "CalcVariance")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(1.0, 2.0, 3.0, 4.0, 5.0))
.withName("CalcVariance")
.variance
assertTransformNameStartsWith(p, "CalcVariance")
}

"PairSCollectionFunctions" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
.withName("SumPerKey")
.sumByKey
assertTransformNameStartsWith(p, "SumPerKey/KvToTuple")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
.withName("SumPerKey")
.sumByKey
assertTransformNameStartsWith(p, "SumPerKey/KvToTuple")
}

"SCollectionWithFanout" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(1, 2, 3))
.withFanout(10)
.withName("Sum")
.sum
assertTransformNameStartsWith(p, "Sum/Values/Values")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(1, 2, 3))
.withFanout(10)
.withName("Sum")
.sum
assertTransformNameStartsWith(p, "Sum/Values/Values")
}

"SCollectionWithHotKeyFanout" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
.withHotKeyFanout(10)
.withName("Sum")
.sumByKey
assertTransformNameStartsWith(p, "Sum/KvToTuple")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
.withHotKeyFanout(10)
.withName("Sum")
.sumByKey
assertTransformNameStartsWith(p, "Sum/KvToTuple")
}

"SCollectionWithSideInput" should "support custom transform name" in {
runWithContext { sc =>
val p1 = sc.parallelize(Seq("a", "b", "c"))
val p2 = sc.parallelize(Seq(1, 2, 3)).asListSideInput
val s = p1
.withSideInputs(p2)
.withName("GetX")
.filter((x, _) => x == "a")
assertTransformNameStartsWith(s, "GetX")
}
val sc = ScioContext()
val p1 = sc.parallelize(Seq("a", "b", "c"))
val p2 = sc.parallelize(Seq(1, 2, 3)).asListSideInput
val s = p1
.withSideInputs(p2)
.withName("GetX")
.filter((x, _) => x == "a")
assertTransformNameStartsWith(s, "GetX")
}

"SCollectionWithSideOutput" should "support custom transform name" in {
runWithContext { sc =>
val p1 = sc.parallelize(Seq("a", "b", "c"))
val p2 = SideOutput[String]()
val (main, side) = p1
.withSideOutputs(p2)
.withName("MakeSideOutput")
.map { (x, s) => s.output(p2, x + "2"); x + "1" }
val sideOut = side(p2)
assertTransformNameStartsWith(main, "MakeSideOutput")
assertTransformNameStartsWith(sideOut, "MakeSideOutput")
}
val sc = ScioContext()
val p1 = sc.parallelize(Seq("a", "b", "c"))
val p2 = SideOutput[String]()
val (main, side) = p1
.withSideOutputs(p2)
.withName("MakeSideOutput")
.map { (x, s) => s.output(p2, x + "2"); x + "1" }
val sideOut = side(p2)
assertTransformNameStartsWith(main, "MakeSideOutput")
assertTransformNameStartsWith(sideOut, "MakeSideOutput")
}

"WindowedSCollection" should "support custom transform name" in {
runWithContext { sc =>
val p = sc
.parallelize(Seq(1, 2, 3, 4, 5))
.toWindowed
.withName("Triple")
.map(x => x.withValue(x.value * 3))
assertTransformNameStartsWith(p, "Triple")
}
val sc = ScioContext()
val p = sc
.parallelize(Seq(1, 2, 3, 4, 5))
.toWindowed
.withName("Triple")
.map(x => x.withValue(x.value * 3))
assertTransformNameStartsWith(p, "Triple")
}

"Joins" should "support custom transform names" in {
runWithContext { sc =>
val p1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13), ("d", 14)))
val inner = p1.withName("inner").join(p2)
val left = p1.withName("left").leftOuterJoin(p2)
val right = p1.withName("right").rightOuterJoin(p2)
val full = p1.withName("full").fullOuterJoin(p2)
assertTransformNameStartsWith(inner, "inner")
assertTransformNameStartsWith(left, "left")
assertTransformNameStartsWith(right, "right")
assertTransformNameStartsWith(full, "full")
}
val sc = ScioContext()
val p1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13), ("d", 14)))
val inner = p1.withName("inner").join(p2)
val left = p1.withName("left").leftOuterJoin(p2)
val right = p1.withName("right").rightOuterJoin(p2)
val full = p1.withName("full").fullOuterJoin(p2)
assertTransformNameStartsWith(inner, "inner")
assertTransformNameStartsWith(left, "left")
assertTransformNameStartsWith(right, "right")
assertTransformNameStartsWith(full, "full")
}

"MultiJoin" should "support custom transform name" in {
runWithContext { sc =>
val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("c", 4)))
val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13), ("d", 14)))
val p = MultiJoin.withName("JoinEm").left(p1, p2)
assertTransformNameStartsWith(p, "JoinEm")
}
val sc = ScioContext()
val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("c", 4)))
val p2 = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13), ("d", 14)))
val p = MultiJoin.withName("JoinEm").left(p1, p2)
assertTransformNameStartsWith(p, "JoinEm")
}

"Duplicate transform name" should "have number to make unique" in {
runWithContext { sc =>
val p1 = sc
.parallelize(1 to 5)
.withName("MyTransform")
.map(_ * 2)
val p2 = p1
.withName("MyTransform")
.map(_ * 3)
val p3 = p1
.withName("MyTransform")
.map(_ * 4)
assertTransformNameStartsWith(p1, "MyTransform")
assertTransformNameStartsWith(p2, "MyTransform2")
assertTransformNameStartsWith(p3, "MyTransform3")
}
val sc = ScioContext()
val p1 = sc
.parallelize(1 to 5)
.withName("MyTransform")
.map(_ * 2)
val p2 = p1
.withName("MyTransform")
.map(_ * 3)
val p3 = p1
.withName("MyTransform")
.map(_ * 4)
assertTransformNameStartsWith(p1, "MyTransform")
assertTransformNameStartsWith(p2, "MyTransform2")
assertTransformNameStartsWith(p3, "MyTransform3")
}

"TransformNameable" should "prevent repeated calls to .withName" in {
val e = the[IllegalArgumentException] thrownBy {
runWithContext { sc =>
sc.parallelize(1 to 5)
.withName("Double")
.withName("DoubleMap")
.map(_ * 2)
}
val sc = ScioContext()
sc.parallelize(1 to 5)
.withName("Double")
.withName("DoubleMap")
.map(_ * 2)
}

val msg = "requirement failed: withName() has already been used to set 'Double' as " +
Expand All @@ -241,22 +229,20 @@ class NamedTransformTest extends NamedTransformSpec {
}

it should "contain file:line only on outer transform" in {
runWithContext { sc =>
val p = sc.parallelize(1 to 5).transform(_.transform(_.map(_ + 1)))
assertTransformNameStartsWith(
p,
"""transform\@\{NamedTransformTest\.scala:\d*\}:\d*/transform:\d*/map:\d*"""
)
}
val sc = ScioContext()
val p = sc.parallelize(1 to 5).transform(_.transform(_.map(_ + 1)))
assertTransformNameStartsWith(
p,
"""transform\@\{NamedTransformTest\.scala:\d*\}:\d*/transform:\d*/map:\d*"""
)
}

it should "support fall back to default transform names" in {
runWithContext { sc =>
val defaultName = sc.tfName(default = Some("default"))
defaultName should be("default")
val sc = ScioContext()
val defaultName = sc.tfName(default = Some("default"))
defaultName should be("default")

val userNamed = sc.withName("UserNamed").tfName(default = Some("default"))
userNamed should be("UserNamed")
}
val userNamed = sc.withName("UserNamed").tfName(default = Some("default"))
userNamed should be("UserNamed")
}
}
Loading