diff --git a/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala b/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala index 0e1a0063ba..382e7d47a1 100644 --- a/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala +++ b/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala @@ -28,7 +28,7 @@ import com.google.cloud.dataflow.sdk.Pipeline import com.google.cloud.dataflow.sdk.PipelineResult.State import com.google.cloud.dataflow.sdk.coders.TableRowJsonCoder import com.google.cloud.dataflow.sdk.{io => gio} -import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptions} +import com.google.cloud.dataflow.sdk.options._ import com.google.cloud.dataflow.sdk.runners.{DataflowPipelineJob, DataflowPipelineRunner} import com.google.cloud.dataflow.sdk.testing.TestPipeline import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn @@ -50,7 +50,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{Buffer => MBuffer, Set => MSet} import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{Failure, Try} +import scala.util.Try import scala.util.control.NonFatal /** Convenience object for creating [[ScioContext]] and [[Args]]. */ @@ -131,34 +131,34 @@ class ScioContext private[scio] (val options: PipelineOptions, import Implicits._ - // TODO: decouple DataflowPipelineOptions - private[scio] def dfOptions: Try[DataflowPipelineOptions] = Try { - options.as(classOf[DataflowPipelineOptions]) - }.orElse { - val name = options.getClass.getSimpleName - Failure(new RuntimeException(s"$name is not DataflowPipelineOptions")) + /** Get PipelineOptions as a more specific sub-type. */ + def optionsAs[T <: PipelineOptions : ClassTag]: Try[T] = Try { + options.as(ScioUtil.classOf[T]) } // Set default name if no app name specified by user - dfOptions.foreach { o => + optionsAs[ApplicationNameOptions].foreach { o => if (o.getAppName == null || o.getAppName.startsWith("ScioContext$")) { this.setName(CallSites.getAppName) } } - private[scio] val testId: Option[String] = dfOptions.toOption.flatMap { o => - if ("JobTest-[0-9]+".r.pattern.matcher(o.getAppName).matches()) { - Some(o.getAppName) - } else { - None + private[scio] val testId: Option[String] = + this.optionsAs[ApplicationNameOptions].toOption.flatMap { o => + if ("JobTest-[0-9]+".r.pattern.matcher(o.getAppName).matches()) { + Some(o.getAppName) + } else { + None + } } - } /** Dataflow pipeline. */ def pipeline: Pipeline = { if (_pipeline == null) { // TODO: make sure this works for other PipelineOptions - dfOptions.foreach(_.setFilesToStage(getFilesToStage(artifacts).asJava)) + this + .optionsAs[DataflowPipelineWorkerPoolOptions] + .foreach(_.setFilesToStage(getFilesToStage(artifacts).asJava)) _pipeline = if (testId.isEmpty) { Pipeline.create(options) } else { @@ -226,7 +226,9 @@ class ScioContext private[scio] (val options: PipelineOptions, // ======================================================================= private lazy val bigQueryClient: BigQueryClient = - BigQueryClient(dfOptions.get.getProject, dfOptions.get.getGcpCredential) + BigQueryClient( + this.optionsAs[DataflowPipelineOptions].get.getProject, + this.optionsAs[GcpOptions].get.getGcpCredential) // ======================================================================= // States @@ -238,10 +240,10 @@ class ScioContext private[scio] (val options: PipelineOptions, throw new RuntimeException("Cannot set name once pipeline is initialized") } // override app name and job name - dfOptions.foreach { o => - o.setAppName(name) - o.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options)) - } + this.optionsAs[ApplicationNameOptions].foreach(_.setAppName(name)) + this + .optionsAs[DataflowPipelineOptions] + .foreach(_.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options))) } /** Close the context. No operation can be performed once the context is closed. */ @@ -621,7 +623,7 @@ class DistCacheScioContext private[scio] (self: ScioContext) { if (self.isTest) { new MockDistCache(testDistCache(DistCacheIO(uri))) } else { - new DistCacheSingle(new URI(uri), initFn, self.dfOptions.get) + new DistCacheSingle(new URI(uri), initFn, self.optionsAs[GcsOptions].get) } } @@ -635,7 +637,7 @@ class DistCacheScioContext private[scio] (self: ScioContext) { if (self.isTest) { new MockDistCache(testDistCache(DistCacheIO(uris.mkString("\t")))) } else { - new DistCacheMulti(uris.map(new URI(_)), initFn, self.dfOptions.get) + new DistCacheMulti(uris.map(new URI(_)), initFn, self.optionsAs[GcsOptions].get) } } diff --git a/scio-examples/src/main/scala/com/spotify/scio/examples/WordCount.scala b/scio-examples/src/main/scala/com/spotify/scio/examples/WordCount.scala index f21a8e3cf5..05a9d0317e 100644 --- a/scio-examples/src/main/scala/com/spotify/scio/examples/WordCount.scala +++ b/scio-examples/src/main/scala/com/spotify/scio/examples/WordCount.scala @@ -35,7 +35,7 @@ runMain object WordCount { def main(cmdlineArgs: Array[String]): Unit = { val (sc, args) = ContextAndArgs(cmdlineArgs) - val dfOptions = sc.options.as(classOf[DataflowPipelineOptions]) + val dfOptions = sc.optionsAs[DataflowPipelineOptions].get val input = args.getOrElse("input", ExampleData.KING_LEAR) val output = args.optional("output").getOrElse( diff --git a/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/GameStats.scala b/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/GameStats.scala index d471720acc..6a14bbccfa 100644 --- a/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/GameStats.scala +++ b/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/GameStats.scala @@ -55,7 +55,7 @@ object GameStats { // scalastyle:off method.length def main(cmdlineArgs: Array[String]): Unit = { val (sc, args) = ContextAndArgs(cmdlineArgs) - val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions])) + val dataflowUtils = new DataflowExampleUtils(sc.optionsAs[DataflowPipelineOptions].get) def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))) diff --git a/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/LeaderBoard.scala b/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/LeaderBoard.scala index f1e49cf9fa..ad58f098ca 100644 --- a/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/LeaderBoard.scala +++ b/scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/LeaderBoard.scala @@ -39,7 +39,7 @@ object LeaderBoard { def main(cmdlineArgs: Array[String]): Unit = { val (sc, args) = ContextAndArgs(cmdlineArgs) - val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions])) + val dataflowUtils = new DataflowExampleUtils(sc.optionsAs[DataflowPipelineOptions].get) def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS") .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST"))) diff --git a/scio-examples/src/main/scala/com/spotify/scio/examples/cookbook/DeDupExample.scala b/scio-examples/src/main/scala/com/spotify/scio/examples/cookbook/DeDupExample.scala index 53bcb2699b..9a654135f4 100644 --- a/scio-examples/src/main/scala/com/spotify/scio/examples/cookbook/DeDupExample.scala +++ b/scio-examples/src/main/scala/com/spotify/scio/examples/cookbook/DeDupExample.scala @@ -34,7 +34,7 @@ runMain object DeDupExample { def main(cmdlineArgs: Array[String]): Unit = { val (sc, args) = ContextAndArgs(cmdlineArgs) - val dfOptions = sc.options.as(classOf[DataflowPipelineOptions]) + val dfOptions = sc.optionsAs[DataflowPipelineOptions].get val input = args.getOrElse("input", ExampleData.SHAKESPEARE_ALL) val output = args.optional("output").getOrElse( diff --git a/scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala b/scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala index b910c4683f..e70ca54529 100644 --- a/scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala +++ b/scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala @@ -140,7 +140,7 @@ package object hdfs { self.distCache(paths)(initFn) } else { //TODO: should upload be asynchronous, blocking on context close - val dfOptions = self.options.as(classOf[DataflowPipelineOptions]) + val dfOptions = self.optionsAs[DataflowPipelineOptions].get require(dfOptions.getStagingLocation != null, "Staging directory not set - use `--stagingLocation`!") require(!paths.contains(null), "Artifact path can't be null") @@ -183,7 +183,7 @@ package object hdfs { val inStream = fs.open(src) //TODO: Should we attempt to detect the Mime type rather than always using MimeTypes.BINARY? - val dfOptions = self.options.as(classOf[DataflowPipelineOptions]) + val dfOptions = self.optionsAs[DataflowPipelineOptions].get val outChannel = Channels.newOutputStream( dfOptions.getGcsUtil.create(GcsPath.fromUri(target), MimeTypes.BINARY))