make PipelineOptions handling more generic #139
nevillelyh committed Jun 22, 2016
1 parent 2fa8a3c commit 4567650
Showing 6 changed files with 29 additions and 31 deletions.
scio-core/src/main/scala/com/spotify/scio/ScioContext.scala (48 changes: 23 additions & 25 deletions)
@@ -28,7 +28,7 @@ import com.google.cloud.dataflow.sdk.Pipeline
 import com.google.cloud.dataflow.sdk.PipelineResult.State
 import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder
 import com.google.cloud.dataflow.sdk.{io => gio}
-import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptions}
+import com.google.cloud.dataflow.sdk.options._
 import com.google.cloud.dataflow.sdk.runners.{DataflowPipelineJob, DataflowPipelineRunner}
 import com.google.cloud.dataflow.sdk.testing.TestPipeline
 import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn
@@ -50,7 +50,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.{Buffer => MBuffer, Set => MSet}
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
-import scala.util.{Failure, Try}
+import scala.util.Try
 import scala.util.control.NonFatal

 /** Convenience object for creating [[ScioContext]] and [[Args]]. */
@@ -131,34 +131,31 @@ class ScioContext private[scio] (val options: PipelineOptions,

   import Implicits._

-  // TODO: decouple DataflowPipelineOptions
-  private[scio] def dfOptions: Try[DataflowPipelineOptions] = Try {
-    options.as(classOf[DataflowPipelineOptions])
-  }.orElse {
-    val name = options.getClass.getSimpleName
-    Failure(new RuntimeException(s"$name is not DataflowPipelineOptions"))
-  }
+  /** Get PipelineOptions as a more specific sub-type. */
+  def optionsAs[T <: PipelineOptions : ClassTag]: T = options.as(ScioUtil.classOf[T])

   // Set default name if no app name specified by user
-  dfOptions.foreach { o =>
+  Try(optionsAs[ApplicationNameOptions]).foreach { o =>
     if (o.getAppName == null || o.getAppName.startsWith("ScioContext$")) {
       this.setName(CallSites.getAppName)
     }
   }

-  private[scio] val testId: Option[String] = dfOptions.toOption.flatMap { o =>
-    if ("JobTest-[0-9]+".r.pattern.matcher(o.getAppName).matches()) {
-      Some(o.getAppName)
-    } else {
-      None
-    }
-  }
+  private[scio] val testId: Option[String] =
+    Try(optionsAs[ApplicationNameOptions]).toOption.flatMap { o =>
+      if ("JobTest-[0-9]+".r.pattern.matcher(o.getAppName).matches()) {
+        Some(o.getAppName)
+      } else {
+        None
+      }
+    }

   /** Dataflow pipeline. */
   def pipeline: Pipeline = {
     if (_pipeline == null) {
-      // TODO: make sure this works for other PipelineOptions
-      dfOptions.foreach(_.setFilesToStage(getFilesToStage(artifacts).asJava))
+      Try(optionsAs[DataflowPipelineWorkerPoolOptions])
+        .foreach(_.setFilesToStage(getFilesToStage(artifacts).asJava))
       _pipeline = if (testId.isEmpty) {
         Pipeline.create(options)
       } else {
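
The new `optionsAs` accessor above is the heart of the change: it resolves the target interface from an implicit `ClassTag` and delegates to `PipelineOptions.as`, which can view the same underlying options through any registered sub-interface, so callers are no longer forced through `DataflowPipelineOptions`. A minimal sketch of the mechanism, assuming Dataflow SDK 1.x on the classpath (`classOfT` below is an illustrative stand-in for `ScioUtil.classOf`, which this diff does not show):

import com.google.cloud.dataflow.sdk.options._
import scala.reflect.ClassTag

object OptionsAsSketch {
  // Likely shape of ScioUtil.classOf[T]: resolve the runtime Class of a
  // type parameter from its implicit ClassTag.
  def classOfT[T: ClassTag]: Class[T] =
    implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]

  // The generic accessor: view any PipelineOptions through a sub-interface T.
  // PipelineOptions instances are dynamic proxies, so `as` re-wraps the same
  // backing state rather than performing a Java cast.
  def optionsAs[T <: PipelineOptions : ClassTag](options: PipelineOptions): T =
    options.as(classOfT[T])

  def main(args: Array[String]): Unit = {
    val options = PipelineOptionsFactory.fromArgs(args).create()
    // Any registered options interface works, not just DataflowPipelineOptions.
    optionsAs[ApplicationNameOptions](options).setAppName("demo")
    println(optionsAs[ApplicationNameOptions](options).getAppName) // demo
  }
}
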
@@ -225,8 +222,10 @@ class ScioContext private[scio] (val options: PipelineOptions,
   // Miscellaneous
   // =======================================================================

-  private lazy val bigQueryClient: BigQueryClient =
-    BigQueryClient(dfOptions.get.getProject, dfOptions.get.getGcpCredential)
+  private lazy val bigQueryClient: BigQueryClient = {
+    val o = optionsAs[GcpOptions]
+    BigQueryClient(o.getProject, o.getGcpCredential)
+  }

   // =======================================================================
   // States
@@ -238,10 +237,9 @@ class ScioContext private[scio] (val options: PipelineOptions,
       throw new RuntimeException("Cannot set name once pipeline is initialized")
     }
     // override app name and job name
-    dfOptions.foreach { o =>
-      o.setAppName(name)
-      o.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options))
-    }
+    Try(optionsAs[ApplicationNameOptions]).foreach(_.setAppName(name))
+    Try(optionsAs[DataflowPipelineOptions])
+      .foreach(_.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options)))
   }

   /** Close the context. No operation can be performed once the context is closed. */
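
Note the recurring `Try(optionsAs[...]).foreach(...)` idiom here and above: each setting is applied only when the options can actually be viewed through the interface that carries it, so a context built from plain `PipelineOptions` no longer trips over Dataflow-specific setters. If the idiom keeps spreading it could be factored out; a hypothetical helper, not part of this commit (`ifOptions` is an invented name):

import com.google.cloud.dataflow.sdk.options._
import scala.reflect.ClassTag
import scala.util.Try

object IfOptionsSketch {
  // Apply a side effect only when the options expose the given interface.
  def ifOptions[T <: PipelineOptions : ClassTag](options: PipelineOptions)(f: T => Unit): Unit =
    Try(options.as(implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]])).foreach(f)

  def main(args: Array[String]): Unit = {
    val options = PipelineOptionsFactory.fromArgs(args).create()
    // Mirrors the new setName: app name and job name live on different interfaces.
    ifOptions[ApplicationNameOptions](options)(_.setAppName("renamed"))
    ifOptions[DataflowPipelineOptions](options) { o =>
      o.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options))
    }
    println(options.as(classOf[DataflowPipelineOptions]).getJobName)
  }
}
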
@@ -621,7 +619,7 @@ class DistCacheScioContext private[scio] (self: ScioContext) {
     if (self.isTest) {
       new MockDistCache(testDistCache(DistCacheIO(uri)))
     } else {
-      new DistCacheSingle(new URI(uri), initFn, self.dfOptions.get)
+      new DistCacheSingle(new URI(uri), initFn, self.optionsAs[GcsOptions])
     }
   }

@@ -635,7 +633,7 @@ class DistCacheScioContext private[scio] (self: ScioContext) {
     if (self.isTest) {
       new MockDistCache(testDistCache(DistCacheIO(uris.mkString("\t"))))
     } else {
-      new DistCacheMulti(uris.map(new URI(_)), initFn, self.dfOptions.get)
+      new DistCacheMulti(uris.map(new URI(_)), initFn, self.optionsAs[GcsOptions])
     }
   }

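The remaining five files migrate call sites mechanically from `sc.options.as(classOf[...])` to `sc.optionsAs[...]`, as the diffs below show. From user code the new accessor reads like this (a sketch against scio at this commit; the option lookups are illustrative):

import com.spotify.scio._
import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, GcpOptions}

object OptionsAsUsage {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)
    // Before this change: sc.options.as(classOf[DataflowPipelineOptions])
    val dfOptions = sc.optionsAs[DataflowPipelineOptions]
    // Narrower interfaces now work just as well:
    val project = sc.optionsAs[GcpOptions].getProject
    println(s"project=$project staging=${dfOptions.getStagingLocation}")
  }
}
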
2 changes: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ runMain
 object WordCount {
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dfOptions = sc.options.as(classOf[DataflowPipelineOptions])
+    val dfOptions = sc.optionsAs[DataflowPipelineOptions]

     val input = args.getOrElse("input", ExampleData.KING_LEAR)
     val output = args.optional("output").getOrElse(
2 changes: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ object GameStats {
   // scalastyle:off method.length
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions]))
+    val dataflowUtils = new DataflowExampleUtils(sc.optionsAs[DataflowPipelineOptions])

     def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS")
       .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST")))
2 changes: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ object LeaderBoard {

   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions]))
+    val dataflowUtils = new DataflowExampleUtils(sc.optionsAs[DataflowPipelineOptions])

     def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS")
       .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST")))
2 changes: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ runMain
 object DeDupExample {
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dfOptions = sc.options.as(classOf[DataflowPipelineOptions])
+    val dfOptions = sc.optionsAs[DataflowPipelineOptions]

     val input = args.getOrElse("input", ExampleData.SHAKESPEARE_ALL)
     val output = args.optional("output").getOrElse(
scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala (4 changes: 2 additions & 2 deletions)

@@ -140,7 +140,7 @@ package object hdfs {
       self.distCache(paths)(initFn)
     } else {
       //TODO: should upload be asynchronous, blocking on context close
-      val dfOptions = self.options.as(classOf[DataflowPipelineOptions])
+      val dfOptions = self.optionsAs[DataflowPipelineOptions]
       require(dfOptions.getStagingLocation != null,
         "Staging directory not set - use `--stagingLocation`!")
       require(!paths.contains(null), "Artifact path can't be null")

@@ -183,7 +183,7 @@ package object hdfs {
     val inStream = fs.open(src)

     //TODO: Should we attempt to detect the Mime type rather than always using MimeTypes.BINARY?
-    val dfOptions = self.options.as(classOf[DataflowPipelineOptions])
+    val dfOptions = self.optionsAs[DataflowPipelineOptions]
     val outChannel = Channels.newOutputStream(
       dfOptions.getGcsUtil.create(GcsPath.fromUri(target), MimeTypes.BINARY))
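
These scio-hdfs call sites still reach for `DataflowPipelineOptions`, but only for `getStagingLocation` and `getGcsUtil`, and `getGcsUtil` is declared on the narrower `GcsOptions` interface, so the second site could presumably be decoupled the same way. A sketch of that possible follow-up, not part of this commit:

import java.io.OutputStream
import java.nio.channels.Channels

import com.google.cloud.dataflow.sdk.options.GcsOptions
import com.google.cloud.dataflow.sdk.util.MimeTypes
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath
import com.spotify.scio.ScioContext

object GcsUploadSketch {
  // Open a GCS output stream through the narrower GcsOptions view,
  // avoiding the Dataflow-specific options interface entirely.
  def gcsOutputStream(sc: ScioContext, target: String): OutputStream =
    Channels.newOutputStream(
      sc.optionsAs[GcsOptions].getGcsUtil.create(GcsPath.fromUri(target), MimeTypes.BINARY))
}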
