Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for named transforms. #357

Merged
merged 11 commits into from
Jan 17, 2017
5 changes: 3 additions & 2 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ object ScioContext {
*/
// scalastyle:off number.of.methods
class ScioContext private[scio] (val options: PipelineOptions,
private var artifacts: List[String]) {
private var artifacts: List[String])
extends TransformNameable {

private implicit val context: ScioContext = this

Expand Down Expand Up @@ -398,7 +399,7 @@ class ScioContext private[scio] (val options: PipelineOptions,

private[scio] def applyInternal[Output <: POutput](root: PTransform[_ >: PBegin, Output])
: Output =
pipeline.apply(CallSites.getCurrent, root)
pipeline.apply(this.tfName, root)

/**
* Get an SCollection for an object file.
Expand Down
340 changes: 170 additions & 170 deletions scio-core/src/main/scala/com/spotify/scio/util/MultiJoin.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ package com.spotify.scio.values
import com.google.cloud.dataflow.sdk.coders.Coder
import com.google.cloud.dataflow.sdk.transforms.{Combine, DoFn, PTransform, ParDo}
import com.google.cloud.dataflow.sdk.values.{KV, PCollection, POutput}
import com.spotify.scio.util.CallSites
import com.spotify.scio.{Implicits, ScioContext}

import scala.reflect.ClassTag

private[values] trait PCollectionWrapper[T] {
private[values] trait PCollectionWrapper[T] extends TransformNameable {

import Implicits._

Expand All @@ -39,7 +38,7 @@ private[values] trait PCollectionWrapper[T] {

private[scio] def applyInternal[Output <: POutput]
(transform: PTransform[_ >: PCollection[T], Output]): Output =
internal.apply(CallSites.getCurrent, transform)
internal.apply(this.tfName, transform)

protected def pApply[U: ClassTag]
(transform: PTransform[_ >: PCollection[T], PCollection[U]]): SCollection[U] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ object SCollection {
def unionAll[T: ClassTag](scs: Iterable[SCollection[T]]): SCollection[T] = {
val o = PCollectionList
.of(scs.map(_.internal).asJava)
.apply(CallSites.getCurrent, Flatten.pCollections())
new SCollectionImpl(o, scs.head.context)
.apply("UnionAll", Flatten.pCollections())
scs.head.context.wrap(o)
}

import scala.language.implicitConversions
Expand Down Expand Up @@ -126,7 +126,7 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
/** Apply a transform. */
private[values] def transform[U: ClassTag](f: SCollection[T] => SCollection[U])
: SCollection[U] = {
val o = internal.apply(CallSites.getCurrent, new PTransform[PCollection[T], PCollection[U]]() {
val o = internal.apply(this.tfName, new PTransform[PCollection[T], PCollection[U]]() {
override def apply(input: PCollection[T]): PCollection[U] = {
f(context.wrap(input)).internal
}
Expand Down Expand Up @@ -163,7 +163,7 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
def union(that: SCollection[T]): SCollection[T] = {
val o = PCollectionList
.of(internal).and(that.internal)
.apply(CallSites.getCurrent, Flatten.pCollections())
.apply(this.tfName, Flatten.pCollections())
context.wrap(o)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import scala.reflect.ClassTag
*/
class SCollectionWithHotKeyFanout[K: ClassTag, V: ClassTag]
(private val self: PairSCollectionFunctions[K, V],
private val hotKeyFanout: Either[K => Int, Int]) {
private val hotKeyFanout: Either[K => Int, Int])
extends TransformNameable {

private def withFanout[K, I, O](combine: Combine.PerKey[K, I, O])
: PerKeyWithHotKeyFanout[K, I, O] = this.hotKeyFanout match {
Expand All @@ -42,6 +43,11 @@ class SCollectionWithHotKeyFanout[K: ClassTag, V: ClassTag]
combine.withHotKeyFanout(f)
}

/**
 * Sets the name for the next transform by forwarding it to `self.self`, then returns this
 * wrapper so that calls can be chained.
 *
 * NOTE(review): `self.self` is presumably the SCollection wrapped by the pair functions;
 * confirm against PairSCollectionFunctions.
 */
override def withName(name: String): this.type = {
  val delegate = self.self
  delegate.withName(name)
  this
}

/**
* Aggregate the values of each key, using given combine functions and a neutral "zero value".
* This function can return a different result type, U, than the type of the values in this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SCollectionWithSideInput[T: ClassTag] private[values] (val internal: PColl
.withOutputTags(_mainTag.tupleTag, sideTags)
.of(transformWithSideOutputsFn(sideOutputs, f))

val pCollectionWrapper = this.internal.apply(CallSites.getCurrent, transform)
val pCollectionWrapper = this.internal.apply("TransformWithSideOutputs", transform)
pCollectionWrapper.getAll.asScala
.mapValues(context.wrap(_).asInstanceOf[SCollection[T]].setCoder(internal.getCoder))
.flatMap{ case(tt, col) => Try{tagToSide(tt.getId) -> col}.toOption }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2016 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.values

import com.spotify.scio.util.CallSites

/**
 * Mixin that lets callers override the name of the next transform via `withName`.
 *
 * The pending name lives in a mutable provider. It starts as `CallSiteNameProvider`, is replaced
 * by a constant provider in `withName`, and is reset to `CallSiteNameProvider` every time
 * `tfName` is read.
 */
trait TransformNameable {
  private var nameProvider: TransformNameProvider = CallSiteNameProvider

  /**
   * Name for the transform about to be applied.
   *
   * Reading this value consumes any name set through `withName`: the provider is reset to
   * `CallSiteNameProvider` before returning.
   */
  private[scio] def tfName: String = {
    val n = nameProvider.name
    nameProvider = CallSiteNameProvider
    n
  }

  /**
   * Set a custom name for the next transform.
   *
   * @param name name to use for the next transform
   * @return this object, so that the call can be chained
   * @throws IllegalArgumentException if a name set by an earlier call has not been consumed yet
   */
  def withName(name: String): this.type = {
    val alreadyNamed = nameProvider match {
      case _: ConstNameProvider => true
      case _ => false
    }
    // Guards against chained withName().withName() calls (requested in code review).
    // The message reads the pending name straight from the provider instead of through tfName:
    // tfName resets the provider, so using it here would discard the pending name as a side
    // effect of building the error message.
    require(!alreadyNamed,
      s"withName() has already been used to set '${nameProvider.name}' " +
        "as the name for the next transform.")
    nameProvider = new ConstNameProvider(name)
    this
  }
}

/** Source of the name to use for a transform. */
private trait TransformNameProvider {
/** The transform name supplied by this provider. */
def name: String
}

/**
 * Default provider: asks `CallSites.getCurrent` for a name each time one is requested.
 *
 * NOTE(review): the name is presumably derived from the caller's call site; confirm in CallSites.
 */
private object CallSiteNameProvider extends TransformNameProvider {
  override def name: String = CallSites.getCurrent
}

/** Provider that always yields the fixed `name` it was constructed with. */
private class ConstNameProvider(override val name: String) extends TransformNameProvider
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright 2016 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.values

import com.spotify.scio.testing.PipelineSpec
import com.spotify.scio.util.MultiJoin

/**
 * Verifies that `withName` controls the name of the next transform applied through each of the
 * wrappers exercised below. Every case inspects the full name of the transform that produced
 * the resulting collection.
 */
class NamedTransformTest extends PipelineSpec {

  /** Full name of the transform that produced the collection behind `wrapper`. */
  private def producerFullName(wrapper: PCollectionWrapper[_]): String =
    wrapper.internal.getProducingTransformInternal.getFullName

  /** Asserts on the complete producing-transform name. */
  private def assertTransformNameEquals(p: PCollectionWrapper[_], tfName: String) =
    producerFullName(p) shouldBe tfName

  /** Asserts only on the first "/"-separated segment of the producing-transform name. */
  private def assertOuterTransformNameEquals(p: PCollectionWrapper[_], tfName: String) =
    producerFullName(p).split("/").head shouldBe tfName

  "ScioContext" should "support custom transform name" in {
    runWithContext { sc =>
      val input = sc.withName("ReadInput").parallelize(Seq("a", "b", "c"))
      assertTransformNameEquals(input, "ReadInput/Read(InMemorySource)")
    }
  }

  "SCollection" should "support custom transform name" in {
    runWithContext { sc =>
      val evens = sc.parallelize(Seq(1, 2, 3, 4, 5))
        .map(_ * 3)
        .withName("OnlyEven").filter(_ % 2 == 0)
      assertTransformNameEquals(evens, "OnlyEven/Filter")
    }
  }

  "DoubleSCollectionFunctions" should "support custom transform name" in {
    runWithContext { sc =>
      val spread = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0, 5.0))
        .withName("CalcVariance").variance
      assertOuterTransformNameEquals(spread, "CalcVariance")
    }
  }

  "PairSCollectionFunctions" should "support custom transform name" in {
    runWithContext { sc =>
      val totals = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
        .withName("SumPerKey").sumByKey
      assertTransformNameEquals(totals, "SumPerKey/KvToTuple")
    }
  }

  "SCollectionWithAccumulator" should "support custom transform name" in {
    runWithContext { sc =>
      val acc = sc.sumAccumulator[Int]("IntSum")
      val tripled = sc.parallelize(Seq(1, 2, 3, 4, 5))
        .withAccumulator(acc)
        .withName("TripleSum").map { (elem, ctx) =>
          val result = elem * 3
          ctx.addValue(acc, result)
          result
        }
      assertTransformNameEquals(tripled, "TripleSum")
    }
  }

  "SCollectionWithFanout" should "support custom transform name" in {
    runWithContext { sc =>
      val total = sc.parallelize(Seq(1, 2, 3)).withFanout(10)
        .withName("Sum").sum
      assertTransformNameEquals(total, "Sum/Values/Values")
    }
  }

  "SCollectionWithHotKeyFanout" should "support custom transform name" in {
    runWithContext { sc =>
      val totals = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3))).withHotKeyFanout(10)
        .withName("Sum").sumByKey
      assertTransformNameEquals(totals, "Sum/KvToTuple")
    }
  }

  "SCollectionWithSideInput" should "support custom transform name" in {
    runWithContext { sc =>
      val letters = sc.parallelize(Seq("a", "b", "c"))
      val numbers = sc.parallelize(Seq(1, 2, 3)).asListSideInput
      val onlyA = letters.withSideInputs(numbers)
        .withName("GetX").filter((x, _) => x == "a")
      assertTransformNameEquals(onlyA, "GetX")
    }
  }

  "SCollectionWithSideOutput" should "support custom transform name" in {
    runWithContext { sc =>
      val letters = sc.parallelize(Seq("a", "b", "c"))
      val sideTag = SideOutput[String]()
      val (mainOut, sideOuts) = letters.withSideOutputs(sideTag)
        .withName("MakeSideOutput").map { (x, ctx) =>
          ctx.output(sideTag, x + "2")
          x + "1"
        }
      assertTransformNameEquals(mainOut, "MakeSideOutput")
      assertTransformNameEquals(sideOuts(sideTag), "MakeSideOutput")
    }
  }

  "WindowedSCollection" should "support custom transform name" in {
    runWithContext { sc =>
      val tripled = sc.parallelize(Seq(1, 2, 3, 4, 5))
        .toWindowed
        .withName("Triple").map(wv => wv.withValue(wv.value * 3))
      assertTransformNameEquals(tripled, "Triple")
    }
  }

  "MultiJoin" should "support custom transform name" in {
    runWithContext { sc =>
      val lhs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("c", 4)))
      val rhs = sc.parallelize(Seq(("a", 11), ("b", 12), ("b", 13), ("d", 14)))
      val joined = MultiJoin.withName("JoinEm").left(lhs, rhs)
      assertTransformNameEquals(joined, "JoinEm")
    }
  }

  "Duplicate transform name" should "have number to make unique" in {
    runWithContext { sc =>
      val first = sc.parallelize(1 to 5)
        .withName("MyTransform").map(_ * 2)
      val second = first
        .withName("MyTransform").map(_ * 3)
      val third = first
        .withName("MyTransform").map(_ * 4)
      assertTransformNameEquals(first, "MyTransform")
      assertTransformNameEquals(second, "MyTransform2")
      assertTransformNameEquals(third, "MyTransform3")
    }
  }

  "TransformNameable" should "prevent repeated calls to .withName" in {
    val thrown = intercept[IllegalArgumentException] {
      runWithContext { sc =>
        // The second withName is expected to throw before map is ever applied.
        val neverBuilt = sc.parallelize(1 to 5)
          .withName("Double").withName("DoubleMap").map(_ * 2)
      }
    }
    val expected = "requirement failed: withName() has already been used to set 'Double'" +
      " as the name for the next transform."
    thrown.getMessage shouldBe expected
  }
}
20 changes: 10 additions & 10 deletions scripts/multijoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def cogroup(out, n):
print >> out, ' .of(tagA, a.toKV.internal)'
for x in vals[1:]:
print >> out, ' .and(tag%s, %s.toKV.internal)' % (x, x.lower())
print >> out, ' .apply(CallSites.getCurrent, CoGroupByKey.create())'
print >> out, ' .apply("CoGroupByKey", CoGroupByKey.create())'

print >> out, ' a.context.wrap(keyed).map { kv =>'
print >> out, ' a.context.wrap(keyed).withName(this.tfName).map { kv =>'
print >> out, ' val (key, result) = (kv.getKey, kv.getValue)'
print >> out, ' (key, (%s))' % ', '.join('result.getAll(tag%s).asScala' % x for x in vals) # NOQA
print >> out, ' }'
Expand All @@ -85,9 +85,9 @@ def join(out, n):
print >> out, ' .of(tagA, a.toKV.internal)'
for x in vals[1:]:
print >> out, ' .and(tag%s, %s.toKV.internal)' % (x, x.lower())
print >> out, ' .apply(CallSites.getCurrent, CoGroupByKey.create())'
print >> out, ' .apply("CoGroupByKey", CoGroupByKey.create())'

print >> out, ' a.context.wrap(keyed).flatMap { kv =>'
print >> out, ' a.context.wrap(keyed).withName(this.tfName).flatMap { kv =>'
print >> out, ' val (key, result) = (kv.getKey, kv.getValue)'
print >> out, ' for {'
for x in reversed(vals):
Expand All @@ -111,9 +111,9 @@ def left(out, n):
print >> out, ' .of(tagA, a.toKV.internal)'
for x in vals[1:]:
print >> out, ' .and(tag%s, %s.toKV.internal)' % (x, x.lower())
print >> out, ' .apply(CallSites.getCurrent, CoGroupByKey.create())'
print >> out, ' .apply("CoGroupByKey", CoGroupByKey.create())'

print >> out, ' a.context.wrap(keyed).flatMap { kv =>'
print >> out, ' a.context.wrap(keyed).withName(this.tfName).flatMap { kv =>'
print >> out, ' val (key, result) = (kv.getKey, kv.getValue)'
print >> out, ' for {'
for (i, x) in enumerate(reversed(vals)):
Expand All @@ -140,9 +140,9 @@ def outer(out, n):
print >> out, ' .of(tagA, a.toKV.internal)'
for x in vals[1:]:
print >> out, ' .and(tag%s, %s.toKV.internal)' % (x, x.lower())
print >> out, ' .apply(CallSites.getCurrent, CoGroupByKey.create())'
print >> out, ' .apply("CoGroupByKey", CoGroupByKey.create())'

print >> out, ' a.context.wrap(keyed).flatMap { kv =>'
print >> out, ' a.context.wrap(keyed).withName(this.tfName).flatMap { kv =>'
print >> out, ' val (key, result) = (kv.getKey, kv.getValue)'
print >> out, ' for {'
for (i, x) in enumerate(reversed(vals)):
Expand Down Expand Up @@ -186,12 +186,12 @@ def main(out):
import com.google.cloud.dataflow.sdk.transforms.join.{CoGroupByKey, KeyedPCollectionTuple} # NOQA
import com.google.cloud.dataflow.sdk.values.TupleTag
import com.google.common.collect.Lists
import com.spotify.scio.values.SCollection
import com.spotify.scio.values.{SCollection, TransformNameable}

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

object MultiJoin {
object MultiJoin extends TransformNameable {

def toOptions[T](xs: Iterator[T]): Iterator[Option[T]] = if (xs.isEmpty) Iterator(None) else xs.map(Option(_))
''').replace(' # NOQA', '').lstrip('\n')
Expand Down