Skip to content

Commit

Permalink
Scio execution graph
Browse files Browse the repository at this point in the history
  • Loading branch information
shnapz committed Sep 30, 2024
1 parent b2c4ff1 commit 63d91f5
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 22 deletions.
45 changes: 28 additions & 17 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io.File
import java.net.URI
import java.nio.file.Files
import com.spotify.scio.coders.{Coder, CoderMaterializer, KVCoder}
import com.spotify.scio.graph.{CustomInput, Parallelize, StepInfo, TransformStep, UnionAll}
import com.spotify.scio.io._
import com.spotify.scio.metrics.Metrics
import com.spotify.scio.options.ScioOptions
Expand Down Expand Up @@ -524,8 +525,8 @@ class ScioContext private[scio] (
private var _onClose: Unit => Unit = identity

/** Wrap a [[org.apache.beam.sdk.values.PCollection PCollection]]. */
def wrap[T](p: PCollection[T]): SCollection[T] =
new SCollectionImpl[T](p, this)
def wrap[T](p: PCollection[T], step: StepInfo): SCollection[T] =
new SCollectionImpl[T](p, this, step)

/** Add callbacks calls when the context is closed. */
private[scio] def onClose(f: Unit => Unit): Unit =
Expand Down Expand Up @@ -687,25 +688,33 @@ class ScioContext private[scio] (

private[scio] def applyTransform[U](
name: Option[String],
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: StepInfo
): SCollection[U] =
wrap(applyInternal(name, root))
wrap(applyInternal(name, root), step)

private[scio] def applyTransform[U](
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: StepInfo
): SCollection[U] =
applyTransform(None, root)
applyTransform(None, root, step)

private[scio] def applyTransform[U](
name: String,
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: StepInfo
): SCollection[U] =
applyTransform(Option(name), root)
applyTransform(Option(name), root, step)

def transform[U](f: ScioContext => SCollection[U]): SCollection[U] = transform(this.tfName)(f)

def transform[U](name: String)(f: ScioContext => SCollection[U]): SCollection[U] =
wrap(transform_(name)(f(_).internal))
def transform[U](name: String)(f: ScioContext => SCollection[U]): SCollection[U] = {
val transformed = transform_(name)(sc => SCollectionOutput(f(sc)))
wrap(
transformed.scioCollection.internal,
TransformStep(name, transformed.scioCollection.step)
)
}

private[scio] def transform_[U <: POutput](f: ScioContext => U): U =
transform_(tfName)(f)
Expand Down Expand Up @@ -760,7 +769,7 @@ class ScioContext private[scio] (
if (this.isTest) {
TestDataManager.getInput(testId.get)(CustomIO[T](name)).toSCollection(this)
} else {
applyTransform(name, transform)
applyTransform(name, transform, CustomInput(name))
}
}

Expand Down Expand Up @@ -788,13 +797,15 @@ class ScioContext private[scio] (
// `T: Coder` context bound is required since `scs` might be empty.
def unionAll[T: Coder](scs: => Iterable[SCollection[T]]): SCollection[T] = {
val tfName = this.tfName // evaluate eagerly to avoid overriding `scs` names
scs match {
scs.toList match {
case Nil => empty()
case contents =>
val sources = contents.map(_.step)
wrap(
PCollectionList
.of(contents.map(_.internal).asJava)
.apply(tfName, Flatten.pCollections())
.apply(tfName, Flatten.pCollections()),
UnionAll(tfName, sources)
)
}
}
Expand All @@ -809,7 +820,7 @@ class ScioContext private[scio] (
def parallelize[T: Coder](elems: Iterable[T]): SCollection[T] =
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
this.applyTransform(Create.of(elems.asJava).withCoder(coder))
this.applyTransform(Create.of(elems.asJava).withCoder(coder), Parallelize)
}

/**
Expand All @@ -822,7 +833,7 @@ class ScioContext private[scio] (
requireNotClosed {
val coder = CoderMaterializer.beam(this, KVCoder(koder, voder))
this
.applyTransform(Create.of(elems.asJava).withCoder(coder))
.applyTransform(Create.of(elems.asJava).withCoder(coder), Parallelize)
.map(kv => (kv.getKey, kv.getValue))
}

Expand All @@ -834,7 +845,7 @@ class ScioContext private[scio] (
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
val v = elems.map(t => TimestampedValue.of(t._1, t._2))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder), Parallelize)
}

/**
Expand All @@ -848,7 +859,7 @@ class ScioContext private[scio] (
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
val v = elems.zip(timestamps).map(t => TimestampedValue.of(t._1, t._2))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder), Parallelize)
}

// =======================================================================
Expand Down
39 changes: 39 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/graph/StepInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.spotify.scio.graph

private[scio] trait StepInfo {
val name: String
val sources: List[StepInfo]
}

private[scio] case class TransformStep(name: String, sources: List[StepInfo]) extends StepInfo

private[scio] object TransformStep {
def apply(name: String, source: StepInfo): TransformStep = TransformStep(name, List(source))
}

// preserve link to an original PTransform?
private[scio] case class CustomInput(name: String) extends StepInfo {
val sources: List[StepInfo] = List.empty
}

private[scio] case class UnionAll(name: String, sources: List[StepInfo]) extends StepInfo

private[scio] object Parallelize extends StepInfo {
override val name: String = "parallelize"
override val sources: List[StepInfo] = List.empty
}

private[scio] case class ReadTextIO(filePattern: String) extends StepInfo {
override val name: String = null
override val sources: List[StepInfo] = List()
}

private[scio] case class TestInput(kind: String) extends StepInfo {
override val name: String = kind
override val sources: List[StepInfo] = List()
}

private[scio] case class FlatMap(source: StepInfo) extends StepInfo {
val name: String = null
override val sources: List[StepInfo] = List(source)
}
3 changes: 2 additions & 1 deletion scio-core/src/main/scala/com/spotify/scio/io/TextIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.nio.channels.Channels
import java.util.Collections
import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.graph.ReadTextIO
import com.spotify.scio.util.ScioUtil
import com.spotify.scio.util.FilenamePolicySupplier
import com.spotify.scio.values.SCollection
Expand Down Expand Up @@ -51,7 +52,7 @@ final case class TextIO(path: String) extends ScioIO[String] {
.withCompression(params.compression)
.withEmptyMatchTreatment(params.emptyMatchTreatment)

sc.applyTransform(t)
sc.applyTransform(t, ReadTextIO(filePattern))
.setCoder(coder)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.spotify.scio.testing

import com.spotify.scio.coders.Coder
import com.spotify.scio.graph.TestInput
import com.spotify.scio.io.ScioIO
import com.spotify.scio.values.SCollection
import com.spotify.scio.{ScioContext, ScioResult}
Expand Down Expand Up @@ -46,7 +47,7 @@ final private[scio] case class TestStreamInputSource[T](
)

override def toSCollection(sc: ScioContext): SCollection[T] =
sc.applyTransform(stream)
sc.applyTransform(stream, TestInput("TestStream"))

override def toString: String = s"TestStream(${stream.getEvents})"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ trait SCollectionSafeSyntax {
tuple
.get(errorTag)
.setCoder(CoderMaterializer.beam(self.context, Coder[(T, Throwable)]))
(self.context.wrap(main), self.context.wrap(errorPipe))
(self.context.wrap(main, FlatMap()), self.context.wrap(errorPipe))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.spotify.scio.values

import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{BeamCoders, Coder}
import com.spotify.scio.graph.StepInfo
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.values.{PCollection, POutput}

Expand All @@ -27,6 +28,8 @@ private[values] trait PCollectionWrapper[T] extends TransformNameable {
/** The [[org.apache.beam.sdk.values.PCollection PCollection]] being wrapped internally. */
val internal: PCollection[T]

val step: StepInfo

implicit def coder: Coder[T] = BeamCoders.getCoder(internal)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import com.spotify.scio.estimators.{
ApproximateUniqueCounter,
ApproximateUniqueCounterByError
}
import com.spotify.scio.graph.StepInfo
import com.spotify.scio.io._
import com.spotify.scio.schemas.{Schema, SchemaMaterializer}
import com.spotify.scio.testing.TestDataManager
Expand Down Expand Up @@ -1763,5 +1764,8 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
io.writeWithContext(this, ())
}

private[scio] class SCollectionImpl[T](val internal: PCollection[T], val context: ScioContext)
extends SCollection[T] {}
private[scio] class SCollectionImpl[T](
val internal: PCollection[T],
val context: ScioContext,
val step: StepInfo
) extends SCollection[T] {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.spotify.scio.values;

import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.values.{PInput, POutput, PValue, TupleTag}

import java.util;

case class SCollectionOutput[T](scioCollection: SCollection[T]) extends POutput {
override def getPipeline: Pipeline = scioCollection.internal.getPipeline

override def expand(): util.Map[TupleTag[_], PValue] = scioCollection.internal.expand()

override def finishSpecifyingOutput(
transformName: String,
input: PInput,
transform: PTransform[_, _]
): Unit = scioCollection.internal.finishSpecifyingOutput(transformName, input, transform)
}

0 comments on commit 63d91f5

Please sign in to comment.