From 8a2d9b68a3df872e8caec86b6eedbb702e737c66 Mon Sep 17 00:00:00 2001 From: Brian Hulette Date: Tue, 14 Jun 2022 14:33:29 -0700 Subject: [PATCH] Document and test overriding batch type inference (#21844) * Document and test overriding batch type inference * address review comments * Update sdks/python/apache_beam/transforms/core.py Co-authored-by: Andy Ye Co-authored-by: Andy Ye --- .../fn_api_runner/fn_runner_test.py | 29 +++++++++ .../apache_beam/transforms/batch_dofn_test.py | 34 ++++++++-- sdks/python/apache_beam/transforms/core.py | 64 +++++++++++++++---- 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 11e126de798f8..c65578eba2589 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -136,6 +136,35 @@ def test_batch_pardo(self): assert_that(res, equal_to([6, 12, 18])) + def test_batch_pardo_override_type_inference(self): + class ArrayMultiplyDoFnOverride(beam.DoFn): + def process_batch(self, batch, *unused_args, + **unused_kwargs) -> Iterator[np.ndarray]: + assert isinstance(batch, np.ndarray) + yield batch * 2 + + # infer_output_type must be defined (when there's no process method), + # otherwise we don't know the input type is the same as output type. 
+ def infer_output_type(self, input_type): + return input_type + + def get_input_batch_type(self, input_element_type): + from apache_beam.typehints.batch import NumpyArray + return NumpyArray[input_element_type] + + def get_output_batch_type(self, input_element_type): + return self.get_input_batch_type(input_element_type) + + with self.create_pipeline() as p: + res = ( + p + | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types( + np.int64) + | beam.ParDo(ArrayMultiplyDoFnOverride()) + | beam.Map(lambda x: x * 3)) + + assert_that(res, equal_to([6, 12, 18])) + def test_batch_pardo_trigger_flush(self): try: utils.check_compiled('apache_beam.coders.coder_impl') diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py index d6b7d2dc9b2a2..1924052121c05 100644 --- a/sdks/python/apache_beam/transforms/batch_dofn_test.py +++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py @@ -46,6 +46,17 @@ def process_batch(self, batch: List[int], *args, **kwargs): yield [element * 2 for element in batch] +class BatchDoFnOverrideTypeInference(beam.DoFn): + def process_batch(self, batch, *args, **kwargs): + yield [element * 2 for element in batch] + + def get_input_batch_type(self, input_element_type): + return List[input_element_type] + + def get_output_batch_type(self, input_element_type): + return List[input_element_type] + + class EitherDoFn(beam.DoFn): def process(self, element: int, *args, **kwargs) -> Iterator[float]: yield element / 2 @@ -75,6 +86,7 @@ def get_test_class_name(cls, num, params_dict): @parameterized_class([ { "dofn": ElementDoFn(), + "input_element_type": int, "expected_process_defined": True, "expected_process_batch_defined": False, "expected_input_batch_type": None, @@ -82,6 +94,7 @@ def get_test_class_name(cls, num, params_dict): }, { "dofn": BatchDoFn(), + "input_element_type": int, "expected_process_defined": False, "expected_process_batch_defined": True, 
"expected_input_batch_type": beam.typehints.List[int], @@ -89,6 +102,15 @@ def get_test_class_name(cls, num, params_dict): }, { "dofn": BatchDoFnNoReturnAnnotation(), + "input_element_type": int, + "expected_process_defined": False, + "expected_process_batch_defined": True, + "expected_input_batch_type": beam.typehints.List[int], + "expected_output_batch_type": beam.typehints.List[int] + }, + { + "dofn": BatchDoFnOverrideTypeInference(), + "input_element_type": int, "expected_process_defined": False, "expected_process_batch_defined": True, "expected_input_batch_type": beam.typehints.List[int], @@ -96,6 +118,7 @@ def get_test_class_name(cls, num, params_dict): }, { "dofn": EitherDoFn(), + "input_element_type": int, "expected_process_defined": True, "expected_process_batch_defined": True, "expected_input_batch_type": beam.typehints.List[int], @@ -103,6 +126,7 @@ def get_test_class_name(cls, num, params_dict): }, { "dofn": ElementToBatchDoFn(), + "input_element_type": int, "expected_process_defined": True, "expected_process_batch_defined": False, "expected_input_batch_type": None, @@ -110,6 +134,7 @@ def get_test_class_name(cls, num, params_dict): }, { "dofn": BatchToElementDoFn(), + "input_element_type": int, "expected_process_defined": False, "expected_process_batch_defined": True, "expected_input_batch_type": beam.typehints.List[int], @@ -127,11 +152,12 @@ def test_process_batch_defined(self): def test_get_input_batch_type(self): self.assertEqual( - self.dofn.get_input_batch_type(), self.expected_input_batch_type) + self.dofn._get_input_batch_type_normalized(self.input_element_type), + self.expected_input_batch_type) def test_get_output_batch_type(self): self.assertEqual( - self.dofn.get_output_batch_type(beam.typehints.Any), + self.dofn._get_output_batch_type_normalized(self.input_element_type), self.expected_output_batch_type) def test_can_yield_batches(self): @@ -152,8 +178,8 @@ def test_map_pardo(self): self.assertTrue(dofn.process_defined) 
self.assertFalse(dofn.process_batch_defined) - self.assertEqual(dofn.get_input_batch_type(), None) - self.assertEqual(dofn.get_output_batch_type(int), None) + self.assertEqual(dofn._get_input_batch_type_normalized(int), None) + self.assertEqual(dofn._get_output_batch_type_normalized(int), None) def test_no_input_annotation_raises(self): p = beam.Pipeline() diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index c5b55574c7e8c..fdc458bce0b3d 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -737,7 +737,23 @@ def process_yields_batches(self) -> bool: def process_batch_yields_elements(self) -> bool: return getattr(self.process_batch, '_beam_yields_elements', False) - def get_input_batch_type(self) -> typing.Optional[TypeConstraint]: + def get_input_batch_type( + self, input_element_type + ) -> typing.Optional[typing.Union[TypeConstraint, type]]: + """Determine the batch type expected as input to process_batch. + + The default implementation of ``get_input_batch_type`` simply observes the + input typehint for the first parameter of ``process_batch``. A Batched DoFn + may override this method if a dynamic approach is required. + + Args: + input_element_type: The **element type** of the input PCollection this + DoFn is being applied to. + + Returns: + ``None`` if this DoFn cannot accept batches, else a Beam typehint or + a native Python typehint. + """ if not self.process_batch_defined: return None input_type = list( @@ -746,10 +762,18 @@ def get_input_batch_type(self) -> typing.Optional[TypeConstraint]: # TODO(BEAM-14340): Consider supporting an alternative (dynamic?) approach # for declaring input type raise TypeError( - f"{self.__class__.__name__}.process_batch() does not have a type " - "annotation on its first parameter. 
This is required for " - "process_batch implementations.") - return typehints.native_type_compatibility.convert_to_beam_type(input_type) + f"Either {self.__class__.__name__}.process_batch() must have a type " + f"annotation on its first parameter, or {self.__class__.__name__} " + "must override get_input_batch_type.") + return input_type + + def _get_input_batch_type_normalized(self, input_element_type): + return typehints.native_type_compatibility.convert_to_beam_type( + self.get_input_batch_type(input_element_type)) + + def _get_output_batch_type_normalized(self, input_element_type): + return typehints.native_type_compatibility.convert_to_beam_type( + self.get_output_batch_type(input_element_type)) @staticmethod def _get_element_type_from_return_annotation(method, input_type): @@ -770,16 +794,32 @@ def _get_element_type_from_return_annotation(method, input_type): f"{method!r}, did you mean Iterator[{return_type}]?") def get_output_batch_type( - self, input_element_type) -> typing.Optional[TypeConstraint]: + self, input_element_type + ) -> typing.Optional[typing.Union[TypeConstraint, type]]: + """Determine the batch type produced by this DoFn's ``process_batch`` + implementation and/or its ``process`` implementation with + ``@DoFn.yields_batches``. + + The default implementation of this method observes the return type + annotations on ``process_batch`` and/or ``process``. A Batched DoFn may + override this method if a dynamic approach is required. + + Args: + input_element_type: The **element type** of the input PCollection this + DoFn is being applied to. + + Returns: + ``None`` if this DoFn will never yield batches, else a Beam typehint or + a native Python typehint. 
+ """ output_batch_type = None if self.process_defined and self.process_yields_batches: - # TODO: Use the element_type passed to infer_output_type instead of - # typehints.Any output_batch_type = self._get_element_type_from_return_annotation( self.process, input_element_type) if self.process_batch_defined and not self.process_batch_yields_elements: process_batch_type = self._get_element_type_from_return_annotation( - self.process_batch, self.get_input_batch_type()) + self.process_batch, + self._get_input_batch_type_normalized(input_element_type)) # TODO: Consider requiring an inheritance relationship rather than # equality @@ -1415,7 +1455,8 @@ def infer_output_type(self, input_type): def infer_batch_converters(self, input_element_type): # TODO: Test this code (in batch_dofn_test) if self.fn.process_batch_defined: - input_batch_type = self.fn.get_input_batch_type() + input_batch_type = self.fn._get_input_batch_type_normalized( + input_element_type) if input_batch_type is None: raise TypeError( @@ -1430,7 +1471,8 @@ def infer_batch_converters(self, input_element_type): self.fn.input_batch_converter = None if self.fn.can_yield_batches: - output_batch_type = self.fn.get_output_batch_type(input_element_type) + output_batch_type = self.fn._get_output_batch_type_normalized( + input_element_type) if output_batch_type is None: # TODO: Mention process method in this error raise TypeError(