From 5feea5b79903c85baf8677351ad65557efb523f3 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 26 Feb 2016 17:30:13 -0800 Subject: [PATCH] Implement InProcessPipelineRunner#run Appropriately construct an evaluation context and executor, and start the pipeline when run is called. Implement runner-provided ExecutionContext and StepContext abstractions. --- ...achedThreadPoolExecutorServiceFactory.java | 34 +++ .../ConsumerTrackingPipelineVisitor.java | 118 ++++++++++ .../inprocess/InProcessPipelineOptions.java | 29 +++ .../inprocess/InProcessPipelineRunner.java | 205 ++++++++++++++++-- .../InProcessPipelineRunnerTest.java | 77 +++++++ 5 files changed, 439 insertions(+), 24 deletions(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java new file mode 100644 index 0000000000000..9e310711d1f89 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * A {@link DefaultValueFactory} that produces cached thread pools via + * {@link Executors#newCachedThreadPool()}. + */ +class CachedThreadPoolExecutorServiceFactory + implements DefaultValueFactory { + @Override + public ExecutorService create(PipelineOptions options) { + return Executors.newCachedThreadPool(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java new file mode 100644 index 0000000000000..2a6a8103809c8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the + * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume + * input after the upstream transform has produced and committed output. 
+ */
+public class ConsumerTrackingPipelineVisitor implements PipelineVisitor {
+  private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers = new HashMap<>();
+  private Collection<AppliedPTransform<?, ?, ?>> rootTransforms = new ArrayList<>();
+  private Collection<PCollectionView<?>> views = new ArrayList<>();
+  private Map<AppliedPTransform<?, ?, ?>, String> stepNames = new HashMap<>();
+  private Set<PValue> toFinalize = new HashSet<>();
+  private int numTransforms = 0;
+
+  @Override
+  public void enterCompositeTransform(TransformTreeNode node) {}
+
+  @Override
+  public void leaveCompositeTransform(TransformTreeNode node) {}
+
+  @Override
+  public void visitTransform(TransformTreeNode node) {
+    toFinalize.removeAll(node.getInputs().keySet()); // these values now have a consumer
+    AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node);
+    stepNames.put(appliedTransform, genStepName()); // exactly one name per transform, roots too
+    if (node.getInput().expand().isEmpty()) {
+      rootTransforms.add(appliedTransform);
+    } else {
+      for (PValue value : node.getInputs().keySet()) {
+        valueToConsumers.get(value).add(appliedTransform);
+      }
+    }
+  }
+
+  private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformTreeNode node) {
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    AppliedPTransform<?, ?, ?> application = AppliedPTransform.of(
+        node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform());
+    return application;
+  }
+
+  @Override
+  public void visitValue(PValue value, TransformTreeNode producer) {
+    toFinalize.add(value); // removed again in visitTransform once some transform consumes it
+    for (PValue expandedValue : value.expand()) {
+      valueToConsumers.put(expandedValue, new ArrayList<AppliedPTransform<?, ?, ?>>());
+      if (expandedValue instanceof PCollectionView) {
+        views.add((PCollectionView<?>) expandedValue);
+      }
+      expandedValue.recordAsOutput(getAppliedTransform(producer));
+    }
+    value.recordAsOutput(getAppliedTransform(producer));
+  }
+
+  private String genStepName() {
+    return String.format("s%s", numTransforms++); // "s0", "s1", ... in visit order
+  }
+
+  public Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> getValueToConsumers() {
+    return valueToConsumers;
+  }
+
+  public Map<AppliedPTransform<?, ?, ?>, String> getStepNames() {
+    return stepNames;
+  }
+
+  public Collection<AppliedPTransform<?, ?, ?>> getRootTransforms() {
+    return rootTransforms;
+  }
+
+  public Collection<PCollectionView<?>> getViews() {
+    return views;
+  }
+
+  public Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> getValueToCustomers() {
+    return valueToConsumers; // misspelled alias of getValueToConsumers(), kept for callers
+  }
+
+  /**
+   * Returns the visited {@link PValue PValues} that no visited transform took as an input.
+   */
+  public Set<PValue> getUnfinalizedPValues() {
+    return toFinalize;
+  }
+}
+
+
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
index f06c9030d1786..1825da0b79edc 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
@@ -17,14 +17,43 @@
 
 import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions;
 import com.google.cloud.dataflow.sdk.options.Default;
+import com.google.cloud.dataflow.sdk.options.Description;
 import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.Validation.Required;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.concurrent.ExecutorService;
 
 /**
  * Options that can be used to configure the {@link InProcessPipelineRunner}.
  */
 public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions {
+  @JsonIgnore
+  @Default.InstanceFactory(CachedThreadPoolExecutorServiceFactory.class)
+  ExecutorService getExecutorService();
+
+  void setExecutorService(ExecutorService executorService);
+
+  /**
+   * Gets the {@link Clock} used by this pipeline. The clock is used in place of accessing the
+   * system time when time values are required by the evaluator.
+   */
   @Default.InstanceFactory(NanosOffsetClock.Factory.class)
+  @Required
+  @Description(
+      "The processing time source used by the pipeline. 
When the current time is "
+          + "needed by the evaluator, the result of clock#now() is used.")
   Clock getClock();
 
   void setClock(Clock clock);
+
+  @Default.Boolean(false)
+  @Description("If the pipeline should block awaiting completion of the pipeline. If set to true, "
+      + "a call to Pipeline#run() will block until all PTransforms are complete. Otherwise, the "
+      + "Pipeline will execute asynchronously. If set to false, the completion of the pipeline can "
+      + "be awaited on by use of InProcessPipelineResult#awaitCompletion().")
+  boolean isBlockOnRun();
+
+  void setBlockOnRun(boolean b);
 }
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index c54a67d0373a6..8186ddf186886 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2015 Google Inc.
+ * Copyright (C) 2016 Google Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  * use this file except in compliance with the License. 
You may obtain a copy of @@ -15,27 +15,47 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import static com.google.common.base.Preconditions.checkArgument; - +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; -import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Flatten.FlattenPCollectionList; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.UserCodeException; import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import 
com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.joda.time.Instant; import java.util.Collection; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; import javax.annotation.Nullable; @@ -44,27 +64,33 @@ * {@link PCollection PCollections}. */ @Experimental -public class InProcessPipelineRunner { - @SuppressWarnings({"rawtypes", "unused"}) +public class InProcessPipelineRunner + extends PipelineRunner { + @SuppressWarnings("rawtypes") private static Map, Class> defaultTransformOverrides = ImmutableMap., Class>builder() + .put(Create.Values.class, InProcessCreate.class) .put(GroupByKey.class, InProcessGroupByKey.class) - .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class) + .put( + CreatePCollectionView.class, + ViewEvaluatorFactory.InProcessCreatePCollectionView.class) .build(); - private static Map, TransformEvaluatorFactory> defaultEvaluatorFactories = - new ConcurrentHashMap<>(); - - /** - * Register a default transform evaluator. 
- */ - public static > void registerTransformEvaluatorFactory( - Class clazz, TransformEvaluatorFactory evaluator) { - checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null, - "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator, - clazz); - } + @SuppressWarnings("rawtypes") + private static Map, TransformEvaluatorFactory> + defaultEvaluatorFactories = + ImmutableMap., TransformEvaluatorFactory>builder() + .put(Read.Bounded.class, new BoundedReadEvaluatorFactory()) + .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory()) + .put(FlattenPCollectionList.class, new FlattenEvaluatorFactory()) + .put(ParDo.Bound.class, new ParDoSingleEvaluatorFactory()) + .put(ParDo.BoundMulti.class, new ParDoMultiEvaluatorFactory()) + .put(ViewEvaluatorFactory.WriteView.class, new ViewEvaluatorFactory()) + .put( + GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly.class, + new GroupByKeyEvaluatorFactory()) + .build(); /** * Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be @@ -75,13 +101,15 @@ public class InProcessPipelineRunner { */ public static interface UncommittedBundle { /** - * Returns the PCollection that the elements of this bundle belong to. + * Returns the PCollection that the elements of this {@link UncommittedBundle} belong to. */ PCollection getPCollection(); /** * Outputs an element to this bundle. * + * The bundle implementation is responsible for properly propagating + * * @param element the element to add to this bundle * @return this bundle */ @@ -105,7 +133,6 @@ public static interface UncommittedBundle { * @param the type of elements contained within this bundle */ public static interface CommittedBundle { - /** * Returns the PCollection that the elements of this bundle belong to. 
*/ @@ -124,8 +151,8 @@ public static interface CommittedBundle { @Nullable Object getKey(); /** - * @return an {@link Iterable} containing all of the elements that have been added to this - * {@link CommittedBundle} + * Returns an {@link Iterable} containing all of the elements that have been added to this + * {@link CommittedBundle}. */ Iterable> getElements(); @@ -154,6 +181,12 @@ public static interface PCollectionViewWriter { //////////////////////////////////////////////////////////////////////////////////////////////// private final InProcessPipelineOptions options; + public static InProcessPipelineRunner createForTest() { + InProcessPipelineOptions options = PipelineOptionsFactory.as(InProcessPipelineOptions.class); + options.setBlockOnRun(true); + return new InProcessPipelineRunner(options); + } + public static InProcessPipelineRunner fromOptions(PipelineOptions options) { return new InProcessPipelineRunner(options.as(InProcessPipelineOptions.class)); } @@ -169,6 +202,130 @@ public InProcessPipelineOptions getPipelineOptions() { return options; } + @Override + public OutputT apply( + PTransform transform, InputT input) { + Class overrideClass = defaultTransformOverrides.get(transform.getClass()); + if (overrideClass != null) { + transform.validate(input); + // It is the responsibility of whoever constructs overrides to ensure this is type safe. + @SuppressWarnings("unchecked") + Class> transformClass = + (Class>) transform.getClass(); + + @SuppressWarnings("unchecked") + Class> customTransformClass = + (Class>) overrideClass; + + PTransform customTransform = + InstanceBuilder.ofType(customTransformClass) + .withArg(transformClass, transform) + .build(); + + // This overrides the contents of the apply method without changing the TransformTreeNode that + // is generated by the PCollection application. 
+      return super.apply(customTransform, input);
+    } else {
+      return super.apply(transform, input);
+    }
+  }
+
+  @Override
+  public InProcessPipelineResult run(Pipeline pipeline) {
+    ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor();
+    pipeline.traverseTopologically(visitor);
+    for (PValue unfinalized : visitor.getUnfinalizedPValues()) {
+      unfinalized.finishSpecifying();
+    }
+
+    InProcessEvaluationContext context =
+        InProcessEvaluationContext.create(this, visitor.getRootTransforms(),
+            visitor.getValueToConsumers(), visitor.getStepNames(), visitor.getViews());
+
+    ExecutorService executorService = context.getPipelineOptions().getExecutorService();
+    InProcessExecutor executor =
+        ExecutorServiceParallelExecutor.create(executorService, context);
+    executor.start(visitor.getRootTransforms());
+
+    Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+        new AggregatorPipelineExtractor(pipeline).getAggregatorSteps();
+    InProcessPipelineResult result =
+        new InProcessPipelineResult(executor, context, aggregatorSteps);
+    if (options.isBlockOnRun()) {
+      try {
+        result.awaitCompletion();
+      } catch (UserCodeException userException) {
+        throw new PipelineExecutionException(userException.getCause());
+      } catch (Throwable t) {
+        throw Throwables.propagate(t); // propagate always throws; 'throw' makes that explicit
+      }
+    }
+    return result;
+  }
+
+  /**
+   * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}.
+   *
+   * Exposes the pipeline's state and aggregator values, and can block until it completes. 
+ */ + public static class InProcessPipelineResult implements PipelineResult { + private final InProcessExecutor executor; + private final InProcessEvaluationContext evaluationContext; + private final Map, Collection>> aggregatorSteps; + private State state; + + private InProcessPipelineResult( + InProcessExecutor executor, + InProcessEvaluationContext evaluationContext, + Map, Collection>> aggregatorSteps) { + this.executor = executor; + this.evaluationContext = evaluationContext; + this.aggregatorSteps = aggregatorSteps; + // Only ever constructed after the executor has started. + this.state = State.RUNNING; + } + + @Override + public State getState() { + return state; + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException { + CounterSet counters = evaluationContext.getCounters(); + Collection> steps = aggregatorSteps.get(aggregator); + Map stepValues = new HashMap<>(); + for (AppliedPTransform transform : evaluationContext.getSteps()) { + if (steps.contains(transform.getTransform())) { + String stepName = + String.format( + "user-%s-%s", evaluationContext.getStepName(transform), aggregator.getName()); + Counter counter = (Counter) counters.getExistingCounter(stepName); + if (counter == null) { + throw new IllegalArgumentException( + "Aggregator " + aggregator + " is not used in this pipeline"); + } + stepValues.put(transform.getFullName(), counter.getAggregate()); + } + } + return new MapAggregatorValues<>(stepValues); + } + + public State awaitCompletion() throws Throwable { + if (!state.isTerminal()) { + try { + executor.awaitCompletion(); + state = State.DONE; + } catch (Throwable t) { + state = State.FAILED; + throw t; + } + } + return state; + } + } + /** * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both * source and intermediate {@link PTransform PTransforms}. 
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
new file mode 100644
index 0000000000000..adb64cd62588f
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessPipelineResult;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Count;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.MapElements;
+import com.google.cloud.dataflow.sdk.transforms.SimpleFunction;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+
+/**
+ * Tests for basic {@link InProcessPipelineRunner} functionality. 
+ */ +@RunWith(JUnit4.class) +public class InProcessPipelineRunnerTest implements Serializable { + @Test + public void wordCountShouldSucceed() throws Throwable { + Pipeline p = getPipeline(); + + PCollection> counts = + p.apply(Create.of("foo", "bar", "foo", "baz", "bar", "foo")) + .apply(MapElements.via(new SimpleFunction() { + @Override + public String apply(String input) { + return input; + } + })) + .apply(Count.perElement()); + PCollection countStrs = + counts.apply(MapElements.via(new SimpleFunction, String>() { + @Override + public String apply(KV input) { + String str = String.format("%s: %s", input.getKey(), input.getValue()); + return str; + } + })); + + DataflowAssert.that(countStrs).containsInAnyOrder("baz: 1", "bar: 2", "foo: 3"); + + InProcessPipelineResult result = ((InProcessPipelineResult) p.run()); + result.awaitCompletion(); + } + + private Pipeline getPipeline() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(InProcessPipelineRunner.class); + + Pipeline p = Pipeline.create(opts); + return p; + } +} +