decouple DataflowPipelineOptions #139
nevillelyh committed Jun 17, 2016
1 parent ae8b810 commit 1b9707a
Showing 7 changed files with 46 additions and 26 deletions.
37 changes: 25 additions & 12 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
@@ -50,6 +50,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.{Buffer => MBuffer, Set => MSet}
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
+import scala.util.{Failure, Try}
 import scala.util.control.NonFatal
 
 /** Convenience object for creating [[ScioContext]] and [[Args]]. */
@@ -70,13 +71,13 @@ object ScioContext {
   def apply(): ScioContext = ScioContext(defaultOptions)
 
   /** Create a new [[ScioContext]] instance. */
-  def apply(options: DataflowPipelineOptions): ScioContext = new ScioContext(options, Nil, None)
+  def apply(options: PipelineOptions): ScioContext = new ScioContext(options, Nil, None)
 
   /** Create a new [[ScioContext]] instance. */
   def apply(artifacts: List[String]): ScioContext = new ScioContext(defaultOptions, artifacts, None)
 
   /** Create a new [[ScioContext]] instance. */
-  def apply(options: DataflowPipelineOptions, artifacts: List[String]): ScioContext =
+  def apply(options: PipelineOptions, artifacts: List[String]): ScioContext =
     new ScioContext(options, artifacts, None)
 
   /** Create a new [[ScioContext]] instance for testing. */
@@ -101,7 +102,7 @@ object ScioContext {
     (PipelineOptionsFactory.fromArgs(dfArgs).as(cls), Args(appArgs))
   }
 
-  private val defaultOptions: DataflowPipelineOptions =
+  private val defaultOptions: PipelineOptions =
     PipelineOptionsFactory.fromArgs(Array.empty).as(classOf[DataflowPipelineOptions])
 
 }
@@ -117,23 +118,33 @@ object ScioContext {
  * @groupname Ungrouped Other Members
  */
 // scalastyle:off number.of.methods
-class ScioContext private[scio] (val options: DataflowPipelineOptions,
+class ScioContext private[scio] (val options: PipelineOptions,
                                  private var artifacts: List[String],
                                  testId: Option[String]) {
 
   private val logger = LoggerFactory.getLogger(ScioContext.getClass)
 
   import Implicits._
 
+  // TODO: decouple DataflowPipelineOptions
+  private def dfOptions: Try[DataflowPipelineOptions] = Try {
+    options.as(classOf[DataflowPipelineOptions])
+  }.orElse {
+    val name = options.getClass.getSimpleName
+    Failure(new RuntimeException(s"$name is not DataflowPipelineOptions"))
+  }
+
   // Set default name if no app name specified by user
-  if (options.getAppName == null || options.getAppName.startsWith("ScioContext$")) {
-    this.setName(CallSites.getAppName)
+  dfOptions.foreach { o =>
+    if (o.getAppName == null || o.getAppName.startsWith("ScioContext$")) {
+      this.setName(CallSites.getAppName)
+    }
   }
 
   /** Dataflow pipeline. */
   def pipeline: Pipeline = {
     if (_pipeline == null) {
-      options.setFilesToStage(getFilesToStage(artifacts).asJava)
+      dfOptions.foreach(_.setFilesToStage(getFilesToStage(artifacts).asJava))
       _pipeline = if (testId.isEmpty) {
         Pipeline.create(options)
       } else {
@@ -201,7 +212,7 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions,
   // =======================================================================
 
   private lazy val bigQueryClient: BigQueryClient =
-    BigQueryClient(options.getProject, options.getGcpCredential)
+    BigQueryClient(dfOptions.get.getProject, dfOptions.get.getGcpCredential)
 
   // =======================================================================
   // States
@@ -213,8 +224,10 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions,
       throw new RuntimeException("Cannot set name once pipeline is initialized")
     }
     // override app name and job name
-    options.setAppName(name)
-    options.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options))
+    dfOptions.foreach { o =>
+      o.setAppName(name)
+      o.setJobName(new DataflowPipelineOptions.JobNameFactory().create(options))
+    }
   }
 
   /** Close the context. No operation can be performed once the context is closed. */
@@ -592,7 +605,7 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions,
     if (this.isTest) {
      new MockDistCache(testDistCache(DistCacheIO(uri)))
     } else {
-      new DistCacheSingle(new URI(uri), initFn, options)
+      new DistCacheSingle(new URI(uri), initFn, dfOptions.get)
     }
   }
 
@@ -606,7 +619,7 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions,
     if (this.isTest) {
       new MockDistCache(testDistCache(DistCacheIO(uris.mkString("\t"))))
     } else {
-      new DistCacheMulti(uris.map(new URI(_)), initFn, options)
+      new DistCacheMulti(uris.map(new URI(_)), initFn, dfOptions.get)
     }
   }
 
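The heart of the change is the dfOptions helper above: ScioContext now holds a plain PipelineOptions, and every Dataflow-specific accessor goes through a guarded downcast wrapped in Try. A minimal standalone sketch of the same pattern, assuming Dataflow SDK 1.x on the classpath (the asDataflow name is hypothetical; ScioContext calls its version dfOptions):

import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptions}
import scala.util.{Failure, Try}

// Succeed only when the options can be viewed as Dataflow options;
// otherwise fail with a message naming the actual options class.
def asDataflow(options: PipelineOptions): Try[DataflowPipelineOptions] =
  Try(options.as(classOf[DataflowPipelineOptions])).orElse {
    Failure(new RuntimeException(
      s"${options.getClass.getSimpleName} is not DataflowPipelineOptions"))
  }

Call sites that merely prefer Dataflow behaviour use foreach, which is a no-op on Failure (setName, setFilesToStage), while call sites that cannot proceed without Dataflow options use .get and let the failure surface (BigQuery client, dist cache).
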

6 changes: 4 additions & 2 deletions scio-examples/src/main/scala/com/spotify/scio/examples/WordCount.scala
@@ -17,6 +17,7 @@
 
 package com.spotify.scio.examples
 
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
 import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath
 import com.spotify.scio._
 import com.spotify.scio.examples.common.ExampleData
@@ -34,11 +35,12 @@ runMain
 object WordCount {
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
+    val dfOptions = sc.options.as(classOf[DataflowPipelineOptions])
 
     val input = args.getOrElse("input", ExampleData.KING_LEAR)
     val output = args.optional("output").getOrElse(
-      if (sc.options.getStagingLocation != null) {
-        GcsPath.fromUri(sc.options.getStagingLocation).resolve("counts.txt").toString
+      if (dfOptions.getStagingLocation != null) {
+        GcsPath.fromUri(dfOptions.getStagingLocation).resolve("counts.txt").toString
       } else {
         throw new IllegalArgumentException("Must specify --output or --stagingLocation")
       })

3 changes: 2 additions & 1 deletion scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/GameStats.scala
@@ -20,6 +20,7 @@ package com.spotify.scio.examples.complete.game
 import java.util.TimeZone
 
 import com.google.cloud.dataflow.examples.common.DataflowExampleUtils
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
 import com.google.cloud.dataflow.sdk.transforms.windowing.{IntervalWindow, OutputTimeFns}
 import com.spotify.scio._
 import com.spotify.scio.experimental._
@@ -54,7 +55,7 @@ object GameStats {
   // scalastyle:off method.length
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dataflowUtils = new DataflowExampleUtils(sc.options)
+    val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions]))
 
     def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS")
       .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST")))

7 changes: 3 additions & 4 deletions scio-examples/src/main/scala/com/spotify/scio/examples/complete/game/LeaderBoard.scala
@@ -20,9 +20,8 @@ package com.spotify.scio.examples.complete.game
 import java.util.TimeZone
 
 import com.google.cloud.dataflow.examples.common.DataflowExampleUtils
-import com.google.cloud.dataflow.sdk.transforms.windowing.{
-  AfterProcessingTime, AfterWatermark, IntervalWindow, Repeatedly
-}
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
+import com.google.cloud.dataflow.sdk.transforms.windowing._
 import com.spotify.scio._
 import com.spotify.scio.experimental._
 import com.spotify.scio.values.WindowOptions
@@ -40,7 +39,7 @@ object LeaderBoard {
 
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
-    val dataflowUtils = new DataflowExampleUtils(sc.options)
+    val dataflowUtils = new DataflowExampleUtils(sc.options.as(classOf[DataflowPipelineOptions]))
 
     def fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS")
       .withZone(DateTimeZone.forTimeZone(TimeZone.getTimeZone("PST")))

6 changes: 4 additions & 2 deletions scio-examples/src/main/scala/com/spotify/scio/examples/cookbook/DeDupExample.scala
@@ -17,6 +17,7 @@
 
 package com.spotify.scio.examples.cookbook
 
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
 import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath
 import com.spotify.scio._
 import com.spotify.scio.examples.common.ExampleData
@@ -33,11 +34,12 @@ runMain
 object DeDupExample {
   def main(cmdlineArgs: Array[String]): Unit = {
     val (sc, args) = ContextAndArgs(cmdlineArgs)
+    val dfOptions = sc.options.as(classOf[DataflowPipelineOptions])
 
     val input = args.getOrElse("input", ExampleData.SHAKESPEARE_ALL)
     val output = args.optional("output").getOrElse(
-      if (sc.options.getStagingLocation != null) {
-        GcsPath.fromUri(sc.options.getStagingLocation).resolve("deduped.txt").toString
+      if (dfOptions.getStagingLocation != null) {
+        GcsPath.fromUri(dfOptions.getStagingLocation).resolve("deduped.txt").toString
       } else {
         throw new IllegalArgumentException("Must specify --output or --stagingLocation")
       })
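On the caller's side, the examples now recover Dataflow-only fields at the point of use instead of relying on the context's static type. A sketch of the resulting user-facing flow, assuming a Dataflow SDK 1.x setup (the staging bucket is hypothetical):

import com.google.cloud.dataflow.sdk.options.{DataflowPipelineOptions, PipelineOptionsFactory}
import com.spotify.scio.ScioContext

// ScioContext now accepts any PipelineOptions, not just Dataflow's...
val opts = PipelineOptionsFactory
  .fromArgs(Array("--stagingLocation=gs://my-bucket/staging"))
  .as(classOf[DataflowPipelineOptions])
val sc = ScioContext(opts)

// ...and Dataflow-specific fields are re-cast where needed, as in the examples above.
val staging = sc.options.as(classOf[DataflowPipelineOptions]).getStagingLocation
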
9 changes: 6 additions & 3 deletions scio-hdfs/src/main/scala/com/spotify/scio/hdfs/package.scala
@@ -28,6 +28,7 @@ import com.google.cloud.dataflow.contrib.hadoop._
 import com.google.cloud.dataflow.contrib.hadoop.simpleauth._
 import com.google.cloud.dataflow.sdk.coders.AvroCoder
 import com.google.cloud.dataflow.sdk.io.{Read, Write}
+import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
 import com.google.cloud.dataflow.sdk.util.MimeTypes
 import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath
 import com.google.cloud.dataflow.sdk.values.KV
@@ -125,7 +126,8 @@ package object hdfs {
       self.distCache(path)(initFn)
     } else {
       //TODO: should upload be asynchronous, blocking on context close
-      require(self.options.getStagingLocation != null,
+      val dfOptions = self.options.as(classOf[DataflowPipelineOptions])
+      require(dfOptions.getStagingLocation != null,
         "Staging directory not set - use `--stagingLocation`!")
       require(path != null, "Artifact path can't be null")
 
@@ -141,7 +143,7 @@
 
       val targetDistCache = new Path("distcache", s"$targetHash-${path.split("/").last}")
 
-      val target = new Path(self.options.getStagingLocation, targetDistCache)
+      val target = new Path(dfOptions.getStagingLocation, targetDistCache)
 
       if (username != null) {
         UserGroupInformation.createRemoteUser(username).doAs(new PrivilegedAction[Unit] {
@@ -164,8 +166,9 @@
       val inStream = fs.open(src)
 
       //TODO: Should we attempt to detect the Mime type rather than always using MimeTypes.BINARY?
+      val dfOptions = self.options.as(classOf[DataflowPipelineOptions])
       val outChannel = Channels.newOutputStream(
-        self.options.getGcsUtil.create(GcsPath.fromUri(target), MimeTypes.BINARY))
+        dfOptions.getGcsUtil.create(GcsPath.fromUri(target), MimeTypes.BINARY))
 
       try {
         ByteStreams.copy(inStream, outChannel)

4 changes: 2 additions & 2 deletions scio-repl/src/main/scala/com/spotify/scio/repl/ReplScioContext.scala
@@ -19,10 +19,10 @@ package com.spotify.scio.repl
 
 import java.io.{OutputStream, PrintStream}
 
-import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions
+import com.google.cloud.dataflow.sdk.options.PipelineOptions
 import com.spotify.scio.{ScioContext, ScioResult}
 
-class ReplScioContext(options: DataflowPipelineOptions,
+class ReplScioContext(options: PipelineOptions,
                       artifacts: List[String],
                       testId: Option[String])
   extends ScioContext(options, artifacts, testId) {
