diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd index 9a3d8d250b3de..ff65a3543ad22 100644 --- a/sdks/python/apache_beam/runners/common.pxd +++ b/sdks/python/apache_beam/runners/common.pxd @@ -71,7 +71,7 @@ cdef class DoFnSignature(object): cdef class DoFnInvoker(object): cdef public DoFnSignature signature - cdef OutputProcessor output_processor + cdef OutputHandler output_handler cdef object user_state_context cdef public object bundle_finalizer_param @@ -124,24 +124,47 @@ cdef class DoFnRunner: cpdef process(self, WindowedValue windowed_value) -cdef class OutputProcessor(object): +cdef class OutputHandler(object): @cython.locals(windowed_value=WindowedValue, output_element_count=int64_t) - cpdef process_outputs(self, WindowedValue element, results, - watermark_estimator=*) + cpdef handle_process_outputs(self, WindowedValue element, results, + watermark_estimator=*) + + @cython.locals(windowed_batch=WindowedBatch, + output_element_count=int64_t) + cpdef handle_process_batch_outputs(self, WindowedBatch input_batch, results, + watermark_estimator=*) -cdef class _OutputProcessor(OutputProcessor): +cdef class _OutputHandler(OutputHandler): cdef object window_fn cdef Receiver main_receivers cdef object tagged_receivers cdef DataflowDistributionCounter per_element_output_counter cdef object output_batch_converter + cdef bint _process_batch_yields_elements + cdef bint _process_yields_batches + + @cython.locals(windowed_value=WindowedValue, + windowed_batch=WindowedBatch, + output_element_count=int64_t) + cpdef handle_process_outputs(self, WindowedValue element, results, + watermark_estimator=*) @cython.locals(windowed_value=WindowedValue, + windowed_batch=WindowedBatch, output_element_count=int64_t) - cpdef process_outputs(self, WindowedValue element, results, - watermark_estimator=*) + cpdef handle_process_batch_outputs(self, WindowedBatch input_batch, results, + watermark_estimator=*) + + @cython.locals(windowed_value=WindowedValue) + cdef inline WindowedValue _maybe_propagate_windowing_info(self, WindowedValue input_element, result) + cdef inline tuple _handle_tagged_output(self, result) + cdef inline _write_value_to_tag(self, tag, WindowedValue windowed_value, + watermark_estimator) + cdef inline _write_batch_to_tag(self, tag, WindowedBatch windowed_batch, + watermark_estimator) + cdef inline _verify_batch_output(self, result) cdef class DoFnContext(object): cdef object label diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 594a7e59cf57f..53064dd23eb47 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -422,7 +422,7 @@ class DoFnInvoker(object): represented by a given DoFnSignature.""" def __init__(self, - output_processor, # type: _OutputProcessor + output_handler, # type: _OutputHandler signature # type: DoFnSignature ): # type: (...) -> None @@ -430,11 +430,11 @@ def __init__(self, """ Initializes `DoFnInvoker` - :param output_processor: an OutputProcessor for receiving elements produced + :param output_handler: an OutputHandler for receiving elements produced by invoking functions of the DoFn. :param signature: a DoFnSignature for the DoFn being invoked """ - self.output_processor = output_processor + self.output_handler = output_handler self.signature = signature self.user_state_context = None # type: Optional[userstate.UserStateContext] self.bundle_finalizer_param = None # type: Optional[core._BundleFinalizerParam] @@ -442,7 +442,7 @@ def __init__(self, @staticmethod def create_invoker( signature, # type: DoFnSignature - output_processor, # type: OutputProcessor + output_handler, # type: OutputHandler context=None, # type: Optional[DoFnContext] side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]] input_args=None, input_kwargs=None, @@ -455,7 +455,7 @@ def create_invoker( """ Creates a new DoFnInvoker based on given arguments. Args: - output_processor: an OutputProcessor for receiving elements produced by + output_handler: an OutputHandler for receiving elements produced by invoking functions of the DoFn. signature: a DoFnSignature for the DoFn being invoked. context: Context to be used when invoking the DoFn (deprecated). @@ -482,12 +482,12 @@ def create_invoker( signature.process_method.defaults or signature.process_batch_method.defaults or signature.is_stateful_dofn()) if not use_per_window_invoker: - return SimpleInvoker(output_processor, signature) + return SimpleInvoker(output_handler, signature) else: if context is None: raise TypeError("Must provide context when not using SimpleInvoker") return PerWindowInvoker( - output_processor, + output_handler, signature, context, side_inputs, @@ -555,7 +555,7 @@ def invoke_start_bundle(self): """Invokes the DoFn.start_bundle() method. """ - self.output_processor.start_bundle_outputs( + self.output_handler.start_bundle_outputs( self.signature.start_bundle_method.method_value()) def invoke_finish_bundle(self): @@ -563,7 +563,7 @@ def invoke_finish_bundle(self): """Invokes the DoFn.finish_bundle() method. """ - self.output_processor.finish_bundle_outputs( + self.output_handler.finish_bundle_outputs( self.signature.finish_bundle_method.method_value()) def invoke_teardown(self): @@ -575,8 +575,8 @@ def invoke_teardown(self): def invoke_user_timer( self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag): - # self.output_processor is Optional, but in practice it won't be None here - self.output_processor.process_outputs( + # self.output_handler is Optional, but in practice it won't be None here + self.output_handler.handle_process_outputs( WindowedValue(None, timestamp, (window, )), self.signature.timer_methods[timer_spec].invoke_timer_callback( self.user_state_context, @@ -604,11 +604,11 @@ class SimpleInvoker(DoFnInvoker): """An invoker that processes elements ignoring windowing information.""" def __init__(self, - output_processor, # type: OutputProcessor + output_handler, # type: OutputHandler signature # type: DoFnSignature ): # type: (...) -> None - super().__init__(output_processor, signature) + super().__init__(output_handler, signature) self.process_method = signature.process_method.method_value self.process_batch_method = signature.process_batch_method.method_value @@ -620,7 +620,7 @@ def invoke_process(self, additional_kwargs=None ): # type: (...) -> Iterable[SplitResultResidual] - self.output_processor.process_outputs( + self.output_handler.handle_process_outputs( windowed_value, self.process_method(windowed_value.value)) return [] @@ -632,7 +632,7 @@ def invoke_process_batch(self, additional_kwargs=None ): # type: (...) -> None - self.output_processor.process_batch_outputs( + self.output_handler.handle_process_batch_outputs( windowed_batch, self.process_batch_method(windowed_batch.values)) @@ -711,7 +711,7 @@ class PerWindowInvoker(DoFnInvoker): """An invoker that processes elements considering windowing information.""" def __init__(self, - output_processor, # type: OutputProcessor + output_handler, # type: OutputHandler signature, # type: DoFnSignature context, # type: DoFnContext side_inputs, # type: Iterable[sideinputs.SideInputMap] @@ -720,7 +720,7 @@ def __init__(self, user_state_context, # type: Optional[userstate.UserStateContext] bundle_finalizer_param # type: Optional[core._BundleFinalizerParam] ): - super().__init__(output_processor, signature) + super().__init__(output_handler, signature) self.side_inputs = side_inputs self.context = context self.process_method = signature.process_method.method_value @@ -978,7 +978,7 @@ def _invoke_process_per_window(self, if additional_kwargs: kwargs_for_process.update(additional_kwargs) - self.output_processor.process_outputs( + self.output_handler.handle_process_outputs( windowed_value, self.process_method(*args_for_process, **kwargs_for_process), self.threadsafe_watermark_estimator) @@ -1062,7 +1062,7 @@ def _invoke_process_batch_per_window( kwargs_for_process_batch = kwargs_for_process_batch or {} - self.output_processor.process_batch_outputs( + self.output_handler.handle_process_batch_outputs( windowed_batch, self.process_batch_method( *args_for_process_batch, **kwargs_for_process_batch), @@ -1372,15 +1372,21 @@ def __init__(self, else: per_element_output_counter = None - # TODO(BEAM-14293): output processor assumes DoFns are batch-to-batch or - # element-to-element, @yields_batches and @yields_elements will break this - # assumption. - output_processor = _OutputProcessor( + output_handler = _OutputHandler( windowing.windowfn, main_receivers, tagged_receivers, per_element_output_counter, - getattr(fn, 'output_batch_converter', None)) + getattr(fn, 'output_batch_converter', None), + getattr( + do_fn_signature.process_method.method_value, + '_beam_yields_batches', + False), + getattr( + do_fn_signature.process_batch_method.method_value, + '_beam_yields_elements', + False), + ) if do_fn_signature.is_stateful_dofn() and not user_state_context: raise Exception( @@ -1390,7 +1396,7 @@ def __init__(self, self.do_fn_invoker = DoFnInvoker.create_invoker( do_fn_signature, - output_processor, + output_handler, self.context, side_inputs, args, @@ -1494,19 +1500,19 @@ def _reraise_augmented(self, exn): raise new_exn.with_traceback(tb) -class OutputProcessor(object): - def process_outputs( +class OutputHandler(object): + def handle_process_outputs( self, windowed_input_element, results, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None raise NotImplementedError - def process_batch_outputs( + def handle_process_batch_outputs( self, windowed_input_element, results, watermark_estimator=None): # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None raise NotImplementedError -class _OutputProcessor(OutputProcessor): +class _OutputHandler(OutputHandler): """Processes output produced by DoFn method invocations.""" def __init__(self, @@ -1515,8 +1521,10 @@ def __init__(self, tagged_receivers, # type: Mapping[Optional[str], Receiver] per_element_output_counter, output_batch_converter, # type: Optional[BatchConverter] + process_yields_batches, # type: bool, + process_batch_yields_elements, # type: bool, ): - """Initializes ``_OutputProcessor``. + """Initializes ``_OutputHandler``. Args: window_fn: a windowing function (WindowFn). @@ -1534,8 +1542,10 @@ def __init__(self, else: self.per_element_output_counter = None self.output_batch_converter = output_batch_converter + self._process_yields_batches = process_yields_batches + self._process_batch_yields_elements = process_batch_yields_elements - def process_outputs( + def handle_process_outputs( self, windowed_input_element, results, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None @@ -1544,118 +1554,150 @@ def process_outputs( A value wrapped in a TaggedOutput object will be unwrapped and then dispatched to the appropriate indexed output. """ - if results is None: - # TODO(BEAM-3937): Remove if block after output counter released. - # Only enable per_element_output_counter when counter cythonized. - if self.per_element_output_counter is not None: - self.per_element_output_counter.add_input(0) - return + results = results or [] # TODO(BEAM-10782): Verify that the results object is a valid iterable type # if performance_runtime_type_check is active, without harming performance - output_element_count = 0 for result in results: - # results here may be a generator, which cannot call len on it. - output_element_count += 1 - tag = None - if isinstance(result, TaggedOutput): - tag = result.tag - if not isinstance(tag, str): - raise TypeError('In %s, tag %s is not a string' % (self, tag)) - result = result.value - if isinstance(result, WindowedValue): - windowed_value = result - if (windowed_input_element is not None and - len(windowed_input_element.windows) != 1): - windowed_value.windows *= len(windowed_input_element.windows) - elif isinstance(result, TimestampedValue): - assign_context = WindowFn.AssignContext(result.timestamp, result.value) - windowed_value = WindowedValue( - result.value, - result.timestamp, - self.window_fn.assign(assign_context)) - if len(windowed_input_element.windows) != 1: - windowed_value.windows *= len(windowed_input_element.windows) - else: - windowed_value = windowed_input_element.with_value(result) - if watermark_estimator is not None: - watermark_estimator.observe_timestamp(windowed_value.timestamp) - if tag is None: - self.main_receivers.receive(windowed_value) - else: - self.tagged_receivers[tag].receive(windowed_value) + tag, result = self._handle_tagged_output(result) + + if not self._process_yields_batches: + # process yields elements + windowed_value = self._maybe_propagate_windowing_info( + windowed_input_element, result) + + output_element_count += 1 + + self._write_value_to_tag(tag, windowed_value, watermark_estimator) + else: # process yields batches + self._verify_batch_output(result) + + if isinstance(result, WindowedBatch): + assert isinstance(result, HomogeneousWindowedBatch) + windowed_batch = result + + if (windowed_input_element is not None and + len(windowed_input_element.windows) != 1): + windowed_batch.windows *= len(windowed_input_element.windows) + else: + windowed_batch = ( + HomogeneousWindowedBatch.from_batch_and_windowed_value( + batch=result, windowed_value=windowed_input_element)) + + output_element_count += self.output_batch_converter.get_length( + windowed_batch.values) + + self._write_batch_to_tag(tag, windowed_batch, watermark_estimator) # TODO(BEAM-3937): Remove if block after output counter released. # Only enable per_element_output_counter when counter cythonized if self.per_element_output_counter is not None: self.per_element_output_counter.add_input(output_element_count) - def process_batch_outputs( + def handle_process_batch_outputs( self, windowed_input_batch, results, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None + # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None - """Dispatch the result of process computation to the appropriate receivers. + """Dispatch the result of process_batch computation to the appropriate + receivers. A value wrapped in a TaggedOutput object will be unwrapped and then dispatched to the appropriate indexed output. """ - if results is None: - # TODO(BEAM-3937): Remove if block after output counter released. - # Only enable per_element_output_counter when counter cythonized. - if self.per_element_output_counter is not None: - self.per_element_output_counter.add_input(0) - return + results = results or [] + output_element_count = 0 + for result in results: + tag, result = self._handle_tagged_output(result) - # TODO(BEAM-10782): Verify that the results object is a valid iterable type - # if performance_runtime_type_check is active, without harming performance + if not self._process_batch_yields_elements: + # process_batch yields batches + assert self.output_batch_converter is not None - assert self.output_batch_converter is not None + self._verify_batch_output(result) - output_element_count = 0 - for result in results: - tag = None - if isinstance(result, TaggedOutput): - tag = result.tag - if not isinstance(tag, str): - raise TypeError('In %s, tag %s is not a string' % (self, tag)) - result = result.value - if isinstance(result, (WindowedValue, TimestampedValue)): - raise TypeError( - f"Received {type(result).__name__} from DoFn that was " - "expected to produce a batch.") - if isinstance(result, WindowedBatch): - assert isinstance(result, HomogeneousWindowedBatch) - windowed_batch = result - - if (windowed_input_batch is not None and - len(windowed_input_batch.windows) != 1): - windowed_batch.windows *= len(windowed_input_batch.windows) - # TODO(BEAM-14352): Add TimestampedBatch, an analogue for TimestampedValue - # and handle it here (see TimestampedValue logic in process_outputs). - else: - # TODO: This should error unless the DoFn was defined with - # @DoFn.yields_batches(output_aligned_with_input=True) - # We should consider also validating that the length is the same as - # windowed_input_batch - windowed_batch = windowed_input_batch.with_values(result) - - output_element_count += self.output_batch_converter.get_length( - windowed_input_batch.values) - - if watermark_estimator is not None: - for timestamp in windowed_batch.timestamps: - watermark_estimator.observe_timestamp(timestamp) - if tag is None: - self.main_receivers.receive_batch(windowed_batch) - else: - self.tagged_receivers[tag].receive_batch(windowed_batch) + if isinstance(result, WindowedBatch): + assert isinstance(result, HomogeneousWindowedBatch) + windowed_batch = result + + if (windowed_input_batch is not None and + len(windowed_input_batch.windows) != 1): + windowed_batch.windows *= len(windowed_input_batch.windows) + else: + windowed_batch = windowed_input_batch.with_values(result) + + output_element_count += self.output_batch_converter.get_length( + windowed_batch.values) + + self._write_batch_to_tag(tag, windowed_batch, watermark_estimator) + else: # process_batch yields elements + assert isinstance(windowed_input_batch, HomogeneousWindowedBatch) + + windowed_value = self._maybe_propagate_windowing_info( + windowed_input_batch.as_empty_windowed_value(), result) + + output_element_count += 1 + + self._write_value_to_tag(tag, windowed_value, watermark_estimator) # TODO(BEAM-3937): Remove if block after output counter released. # Only enable per_element_output_counter when counter cythonized if self.per_element_output_counter is not None: self.per_element_output_counter.add_input(output_element_count) + def _maybe_propagate_windowing_info(self, windowed_input_element, result): + # type: (WindowedValue, Any) -> WindowedValue + if isinstance(result, WindowedValue): + windowed_value = result + if (windowed_input_element is not None and + len(windowed_input_element.windows) != 1): + windowed_value.windows *= len(windowed_input_element.windows) + return windowed_value + + elif isinstance(result, TimestampedValue): + assign_context = WindowFn.AssignContext(result.timestamp, result.value) + windowed_value = WindowedValue( + result.value, result.timestamp, self.window_fn.assign(assign_context)) + if len(windowed_input_element.windows) != 1: + windowed_value.windows *= len(windowed_input_element.windows) + return windowed_value + + else: + return windowed_input_element.with_value(result) + + def _handle_tagged_output(self, result): + if isinstance(result, TaggedOutput): + tag = result.tag + if not isinstance(tag, str): + raise TypeError('In %s, tag %s is not a string' % (self, tag)) + return tag, result.value + return None, result + + def _write_value_to_tag(self, tag, windowed_value, watermark_estimator): + if watermark_estimator is not None: + watermark_estimator.observe_timestamp(windowed_value.timestamp) + + if tag is None: + self.main_receivers.receive(windowed_value) + else: + self.tagged_receivers[tag].receive(windowed_value) + + def _write_batch_to_tag(self, tag, windowed_batch, watermark_estimator): + if watermark_estimator is not None: + for timestamp in windowed_batch.timestamps: + watermark_estimator.observe_timestamp(timestamp) + + if tag is None: + self.main_receivers.receive_batch(windowed_batch) + else: + self.tagged_receivers[tag].receive_batch(windowed_batch) + + def _verify_batch_output(self, result): + if isinstance(result, (WindowedValue, TimestampedValue)): + raise TypeError( + f"Received {type(result).__name__} from DoFn that was " + "expected to produce a batch.") + def start_bundle_outputs(self, results): """Validate that start_bundle does not output any elements""" if results is None: diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 381394ef7221b..57510be749d72 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -37,7 +37,7 @@ from apache_beam.runners.common import DoFnContext from apache_beam.runners.common import DoFnInvoker from apache_beam.runners.common import DoFnSignature -from apache_beam.runners.common import OutputProcessor +from apache_beam.runners.common import OutputHandler from apache_beam.runners.direct.evaluation_context import DirectStepContext from apache_beam.runners.direct.util import KeyedWorkItem from apache_beam.runners.direct.watermark_manager import WatermarkManager @@ -121,7 +121,7 @@ def __init__(self, do_fn): def start_bundle(self): self._invoker = DoFnInvoker.create_invoker( self._signature, - output_processor=_NoneShallPassOutputProcessor(), + output_processor=_NoneShallPassOutputHandler(), process_invocation=False) def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): @@ -142,7 +142,7 @@ def start_bundle(self): signature = DoFnSignature(self._do_fn) self._invoker = DoFnInvoker.create_invoker( signature, - output_processor=_NoneShallPassOutputProcessor(), + output_processor=_NoneShallPassOutputHandler(), process_invocation=False) def process(self, element_and_restriction, *args, **kwargs): @@ -268,7 +268,7 @@ def __init__(self, sdf, args_for_invoker, kwargs_for_invoker): 'watermark_estimator_state') self.watermark_hold_tag = _ReadModifyWriteStateTag('watermark_hold') self._process_element_invoker = None - self._output_processor = _OutputProcessor() + self._output_processor = _OutputHandler() self.sdf_invoker = DoFnInvoker.create_invoker( DoFnSignature(self.sdf), @@ -536,7 +536,7 @@ def initiate_checkpoint(): yield result -class _OutputProcessor(OutputProcessor): +class _OutputHandler(OutputHandler): def __init__(self): self.output_iter = None @@ -549,7 +549,7 @@ def reset(self): self.output_iter = None -class _NoneShallPassOutputProcessor(OutputProcessor): +class _NoneShallPassOutputHandler(OutputHandler): def process_outputs( self, windowed_input_element, output_iter, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 07b5486e9bad8..11e126de798f8 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -293,6 +293,51 @@ def infer_output_type(self, input_type): 9*9 # [ 9, 14) ])) + def test_batch_to_element_pardo(self): + class ArraySumDoFn(beam.DoFn): + @beam.DoFn.yields_elements + def process_batch(self, batch: np.ndarray, *unused_args, + **unused_kwargs) -> Iterator[np.int64]: + yield batch.sum() + + def infer_output_type(self, input_type): + assert input_type == np.int64 + return np.int64 + + with self.create_pipeline() as p: + res = ( + p + | beam.Create(np.array(range(100), dtype=np.int64)).with_output_types( + np.int64) + | beam.ParDo(ArrayMultiplyDoFn()) + | beam.ParDo(ArraySumDoFn()) + | beam.CombineGlobally(sum)) + + assert_that(res, equal_to([99 * 50 * 2])) + + def test_element_to_batch_pardo(self): + class ArrayProduceDoFn(beam.DoFn): + @beam.DoFn.yields_batches + def process(self, element: np.int64, *unused_args, + **unused_kwargs) -> Iterator[np.ndarray]: + yield np.array([element] * int(element)) + + # infer_output_type must be defined (when there's no process method), + # otherwise we don't know the input type is the same as output type. + def infer_output_type(self, input_type): + return np.int64 + + with self.create_pipeline() as p: + res = ( + p + | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types( + np.int64) + | beam.ParDo(ArrayProduceDoFn()) + | beam.ParDo(ArrayMultiplyDoFn()) + | beam.Map(lambda x: x * 3)) + + assert_that(res, equal_to([6, 12, 12, 18, 18, 18])) + def test_pardo_large_input(self): try: utils.check_compiled('apache_beam.coders.coder_impl') diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py index f1fc7eda09394..d6b7d2dc9b2a2 100644 --- a/sdks/python/apache_beam/transforms/batch_dofn_test.py +++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py @@ -22,6 +22,7 @@ import unittest from typing import Iterator from typing import List +from typing import Tuple from typing import no_type_check from parameterized import parameterized_class @@ -54,6 +55,19 @@ def process_batch(self, batch: List[int], *args, yield [element / 2 for element in batch] +class ElementToBatchDoFn(beam.DoFn): + @beam.DoFn.yields_batches + def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]: + yield [element] * element + + +class BatchToElementDoFn(beam.DoFn): + @beam.DoFn.yields_elements + def process_batch(self, batch: List[int], *args, + **kwargs) -> Iterator[Tuple[int, int]]: + yield (sum(batch), len(batch)) + + def get_test_class_name(cls, num, params_dict): return "%s_%s" % (cls.__name__, params_dict['dofn'].__class__.__name__) @@ -61,47 +75,68 @@ def get_test_class_name(cls, num, params_dict): @parameterized_class([ { "dofn": ElementDoFn(), - "process_defined": True, - "process_batch_defined": False, - "input_batch_type": None, - "output_batch_type": None + "expected_process_defined": True, + "expected_process_batch_defined": False, + "expected_input_batch_type": None, + "expected_output_batch_type": None }, { "dofn": BatchDoFn(), - "process_defined": False, - "process_batch_defined": True, - "input_batch_type": beam.typehints.List[int], - "output_batch_type": beam.typehints.List[float] + "expected_process_defined": False, + "expected_process_batch_defined": True, + "expected_input_batch_type": beam.typehints.List[int], + "expected_output_batch_type": beam.typehints.List[float] }, { "dofn": BatchDoFnNoReturnAnnotation(), - "process_defined": False, - "process_batch_defined": True, - "input_batch_type": beam.typehints.List[int], - "output_batch_type": beam.typehints.List[int] + "expected_process_defined": False, + "expected_process_batch_defined": True, + "expected_input_batch_type": beam.typehints.List[int], + "expected_output_batch_type": beam.typehints.List[int] }, { "dofn": EitherDoFn(), - "process_defined": True, - "process_batch_defined": True, - "input_batch_type": beam.typehints.List[int], - "output_batch_type": beam.typehints.List[float] + "expected_process_defined": True, + "expected_process_batch_defined": True, + "expected_input_batch_type": beam.typehints.List[int], + "expected_output_batch_type": beam.typehints.List[float] + }, + { + "dofn": ElementToBatchDoFn(), + "expected_process_defined": True, + "expected_process_batch_defined": False, + "expected_input_batch_type": None, + "expected_output_batch_type": beam.typehints.List[int] + }, + { + "dofn": BatchToElementDoFn(), + "expected_process_defined": False, + "expected_process_batch_defined": True, + "expected_input_batch_type": beam.typehints.List[int], + "expected_output_batch_type": None, }, ], class_name_func=get_test_class_name) class BatchDoFnParameterizedTest(unittest.TestCase): def test_process_defined(self): - self.assertEqual(self.dofn.process_defined, self.process_defined) + self.assertEqual(self.dofn.process_defined, self.expected_process_defined) def test_process_batch_defined(self): self.assertEqual( - self.dofn.process_batch_defined, self.process_batch_defined) + self.dofn.process_batch_defined, self.expected_process_batch_defined) def test_get_input_batch_type(self): - self.assertEqual(self.dofn.get_input_batch_type(), self.input_batch_type) + self.assertEqual( + self.dofn.get_input_batch_type(), self.expected_input_batch_type) def test_get_output_batch_type(self): - self.assertEqual(self.dofn.get_output_batch_type(), self.output_batch_type) + self.assertEqual( + self.dofn.get_output_batch_type(beam.typehints.Any), + self.expected_output_batch_type) + + def test_can_yield_batches(self): + expected = self.expected_output_batch_type is not None + self.assertEqual(self.dofn.can_yield_batches, expected) class BatchDoFnNoInputAnnotation(beam.DoFn): @@ -118,7 +153,7 @@ def test_map_pardo(self): self.assertTrue(dofn.process_defined) self.assertFalse(dofn.process_batch_defined) self.assertEqual(dofn.get_input_batch_type(), None) - self.assertEqual(dofn.get_output_batch_type(), None) + self.assertEqual(dofn.get_output_batch_type(int), None) def test_no_input_annotation_raises(self): p = beam.Pipeline() diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b544c03c7442f..5b2c43ad9e458 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -581,6 +581,26 @@ def wrapper(process_fn): return wrapper + @staticmethod + def yields_elements(fn): + if not fn.__name__ in ('process', 'process_batch'): + raise TypeError( + "@yields_elements must be applied to a process or " + f"process_batch method, got {fn!r}.") + + fn._beam_yields_elements = True + return fn + + @staticmethod + def yields_batches(fn): + if not fn.__name__ in ('process', 'process_batch'): + raise TypeError( + "@yields_elements must be applied to a process or " + f"process_batch method, got {fn!r}.") + + fn._beam_yields_batches = True + return fn + def default_label(self): return self.__class__.__name__ @@ -703,6 +723,20 @@ def process_batch_defined(self) -> bool: if hasattr(self.process_batch, '__self__') else self.process_batch) != DoFn.process_batch + @property + def can_yield_batches(self) -> bool: + return ( + (self.process_defined and self.process_yields_batches) or + (self.process_batch_defined and not self.process_batch_yields_elements)) + + @property + def process_yields_batches(self) -> bool: + return getattr(self.process, '_beam_yields_batches', False) + + @property + def process_batch_yields_elements(self) -> bool: + return getattr(self.process_batch, '_beam_yields_elements', False) + def get_input_batch_type(self) -> typing.Optional[TypeConstraint]: if not self.process_batch_defined: return None @@ -717,14 +751,12 @@ def get_input_batch_type(self) -> typing.Optional[TypeConstraint]: "process_batch implementations.") return typehints.native_type_compatibility.convert_to_beam_type(input_type) - def get_output_batch_type(self) -> typing.Optional[TypeConstraint]: - if not self.process_batch_defined: - return None - return_type = inspect.signature(self.process_batch).return_annotation + @staticmethod + def _get_element_type_from_return_annotation(method, input_type): + return_type = inspect.signature(method).return_annotation if return_type == inspect.Signature.empty: # output type not annotated, try to infer it - return_type = trivial_inference.infer_return_type( - self.process_batch, [self.get_input_batch_type()]) + return_type = trivial_inference.infer_return_type(method, [input_type]) return_type = typehints.native_type_compatibility.convert_to_beam_type( return_type) @@ -734,8 +766,34 @@ def get_output_batch_type(self) -> typing.Optional[TypeConstraint]: return return_type.yielded_type else: raise TypeError( - "Expected Iterator return type annotation, did you mean " - f"Iterator[{return_type}]") + "Expected Iterator in return type annotation for " + f"{method!r}, did you mean Iterator[{return_type}]?") + + def get_output_batch_type( + self, input_element_type) -> typing.Optional[TypeConstraint]: + output_batch_type = None + if self.process_defined and self.process_yields_batches: + # TODO: Use the element_type passed to infer_output_type instead of + # typehints.Any + output_batch_type = self._get_element_type_from_return_annotation( + self.process, input_element_type) + if self.process_batch_defined and not self.process_batch_yields_elements: + process_batch_type = self._get_element_type_from_return_annotation( + self.process_batch, self.get_input_batch_type()) + + # TODO: Consider requiring an inheritance relationship rather than + # equality + if (output_batch_type is not None and + (not process_batch_type == output_batch_type)): + raise TypeError( + f"DoFn {self!r} yields batches from both process and " + "process_batch, but they produce different types:\n" + f" process: {output_batch_type}\n" + f" process_batch: {process_batch_type!r}") + + output_batch_type = process_batch_type + + return output_batch_type def _strip_output_annotations(self, type_hint): annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput) @@ -1355,8 +1413,7 @@ def infer_output_type(self, input_type): return self.fn.infer_output_type(input_type) def infer_batch_converters(self, input_element_type): - # This assumes batch input implies batch output - # TODO(BEAM-14293): Define and handle yields_batches and yields_elements + # TODO: Test this code (in batch_dofn_test) if self.fn.process_batch_defined: input_batch_type = self.fn.get_input_batch_type() @@ -1365,30 +1422,28 @@ def infer_batch_converters(self, input_element_type): "process_batch method on {self.fn!r} does not have " "an input type annoation") - output_batch_type = self.fn.get_output_batch_type() - if output_batch_type is None: - raise TypeError( - "process_batch method on {self.fn!r} does not have " - "a return type annoation") - # Generate a batch converter to convert between the input type and the # (batch) input type of process_batch self.fn.input_batch_converter = BatchConverter.from_typehints( element_type=input_element_type, batch_type=input_batch_type) + else: + self.fn.input_batch_converter = None + + if self.fn.can_yield_batches: + output_batch_type = self.fn.get_output_batch_type(input_element_type) + if output_batch_type is None: + # TODO: Mention process method in this error + raise TypeError( + f"process_batch method on {self.fn!r} does not have " + "a return type annoation") # Generate a batch converter to convert between the output type and the # (batch) output type of process_batch output_element_type = self.infer_output_type(input_element_type) self.fn.output_batch_converter = BatchConverter.from_typehints( element_type=output_element_type, batch_type=output_batch_type) - - def infer_output_batch_type(self): - # TODO(BEAM-14293): Handle process() with @yields_batch - if not self.fn.process_batch_defined: - return - - batch_type = self.fn.get_output_batch_type() - return batch_type + else: + self.fn.output_batch_converter = None def make_fn(self, fn, has_side_inputs): if isinstance(fn, DoFn): @@ -1427,8 +1482,7 @@ def expand(self, pcoll): key_coder, self) - if self.dofn.process_batch_defined: - self.infer_batch_converters(pcoll.element_type) + self.infer_batch_converters(pcoll.element_type) return pvalue.PCollection.from_(pcoll) diff --git a/sdks/python/apache_beam/utils/windowed_value.py b/sdks/python/apache_beam/utils/windowed_value.py index d80becb41c01c..7e864c3764e25 100644 --- a/sdks/python/apache_beam/utils/windowed_value.py +++ b/sdks/python/apache_beam/utils/windowed_value.py @@ -354,6 +354,12 @@ def as_windowed_values(self, explode_fn: Callable) -> Iterable[WindowedValue]: for value in explode_fn(self._wv.value): yield self._wv.with_value(value) + def as_empty_windowed_value(self): + """Get a single WindowedValue with identical windowing information to this + HomogeneousWindowedBatch, but with value=None. Useful for re-using APIs that + pull windowing information from a WindowedValue.""" + return self._wv.with_value(None) + def __eq__(self, other): if isinstance(other, HomogeneousWindowedBatch): return self._wv == other._wv @@ -362,6 +368,11 @@ def __eq__(self, other): def __hash__(self): return hash(self._wv) + @staticmethod + def from_batch_and_windowed_value( + *, batch, windowed_value: WindowedValue) -> 'WindowedBatch': + return HomogeneousWindowedBatch(windowed_value.with_value(batch)) + @staticmethod def from_windowed_values( windowed_values: Sequence[WindowedValue], *,