Skip to content

Commit

Permalink
Document and test overriding batch type inference (#21844)
Browse files Browse the repository at this point in the history
* Document and test overriding batch type inference

* address review comments

* Update sdks/python/apache_beam/transforms/core.py

Co-authored-by: Andy Ye <andyye333@gmail.com>

Co-authored-by: Andy Ye <andyye333@gmail.com>
  • Loading branch information
TheNeuralBit and yeandy authored Jun 14, 2022
1 parent e7c021d commit 5f04b97
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,35 @@ def test_batch_pardo(self):

assert_that(res, equal_to([6, 12, 18]))

def test_batch_pardo_override_type_inference(self):
  """A batched DoFn may bypass annotation-based batch type inference by
  overriding ``get_input_batch_type``/``get_output_batch_type`` directly."""
  class DoubleArraysDoFn(beam.DoFn):
    def process_batch(self, batch, *unused_args,
                      **unused_kwargs) -> Iterator[np.ndarray]:
      assert isinstance(batch, np.ndarray)
      yield batch * 2

    # infer_output_type must be defined (when there's no process method),
    # otherwise we don't know the input type is the same as output type.
    def infer_output_type(self, input_type):
      return input_type

    def get_input_batch_type(self, input_element_type):
      from apache_beam.typehints.batch import NumpyArray
      return NumpyArray[input_element_type]

    def get_output_batch_type(self, input_element_type):
      return self.get_input_batch_type(input_element_type)

  with self.create_pipeline() as pipeline:
    result = (
        pipeline
        | beam.Create(np.array([1, 2, 3], dtype=np.int64)).with_output_types(
            np.int64)
        | beam.ParDo(DoubleArraysDoFn())
        | beam.Map(lambda x: x * 3))

    assert_that(result, equal_to([6, 12, 18]))

def test_batch_pardo_trigger_flush(self):
try:
utils.check_compiled('apache_beam.coders.coder_impl')
Expand Down
34 changes: 30 additions & 4 deletions sdks/python/apache_beam/transforms/batch_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def process_batch(self, batch: List[int], *args, **kwargs):
yield [element * 2 for element in batch]


class BatchDoFnOverrideTypeInference(beam.DoFn):
  """Batched DoFn that declares its batch types by overriding
  ``get_input_batch_type``/``get_output_batch_type`` instead of relying on
  parameter/return annotations on ``process_batch``."""
  def process_batch(self, batch, *args, **kwargs):
    yield [x * 2 for x in batch]

  def get_input_batch_type(self, input_element_type):
    # Batches are plain lists of the input element type.
    return List[input_element_type]

  def get_output_batch_type(self, input_element_type):
    # Output batches have the same shape as input batches.
    return List[input_element_type]


class EitherDoFn(beam.DoFn):
def process(self, element: int, *args, **kwargs) -> Iterator[float]:
yield element / 2
Expand Down Expand Up @@ -75,41 +86,55 @@ def get_test_class_name(cls, num, params_dict):
@parameterized_class([
{
"dofn": ElementDoFn(),
"input_element_type": int,
"expected_process_defined": True,
"expected_process_batch_defined": False,
"expected_input_batch_type": None,
"expected_output_batch_type": None
},
{
"dofn": BatchDoFn(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
"expected_output_batch_type": beam.typehints.List[float]
},
{
"dofn": BatchDoFnNoReturnAnnotation(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
"expected_output_batch_type": beam.typehints.List[int]
},
{
"dofn": BatchDoFnOverrideTypeInference(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
"expected_output_batch_type": beam.typehints.List[int]
},
{
"dofn": EitherDoFn(),
"input_element_type": int,
"expected_process_defined": True,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
"expected_output_batch_type": beam.typehints.List[float]
},
{
"dofn": ElementToBatchDoFn(),
"input_element_type": int,
"expected_process_defined": True,
"expected_process_batch_defined": False,
"expected_input_batch_type": None,
"expected_output_batch_type": beam.typehints.List[int]
},
{
"dofn": BatchToElementDoFn(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
Expand All @@ -127,11 +152,12 @@ def test_process_batch_defined(self):

def test_get_input_batch_type(self):
  """Normalized input batch type should equal the parameterized expectation."""
  actual = self.dofn._get_input_batch_type_normalized(self.input_element_type)
  self.assertEqual(actual, self.expected_input_batch_type)

def test_get_output_batch_type(self):
  """Normalized output batch type should equal the parameterized expectation."""
  actual = self.dofn._get_output_batch_type_normalized(self.input_element_type)
  self.assertEqual(actual, self.expected_output_batch_type)

def test_can_yield_batches(self):
Expand All @@ -152,8 +178,8 @@ def test_map_pardo(self):

self.assertTrue(dofn.process_defined)
self.assertFalse(dofn.process_batch_defined)
self.assertEqual(dofn.get_input_batch_type(), None)
self.assertEqual(dofn.get_output_batch_type(int), None)
self.assertEqual(dofn._get_input_batch_type_normalized(int), None)
self.assertEqual(dofn._get_output_batch_type_normalized(int), None)

def test_no_input_annotation_raises(self):
p = beam.Pipeline()
Expand Down
64 changes: 53 additions & 11 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,23 @@ def process_yields_batches(self) -> bool:
def process_batch_yields_elements(self) -> bool:
  """Whether this DoFn's ``process_batch`` yields individual elements.

  Reads the ``_beam_yields_elements`` marker attribute on the
  ``process_batch`` method (presumably stamped by a decorator elsewhere in
  the codebase -- confirm); defaults to False when the marker is absent.
  """
  return getattr(self.process_batch, '_beam_yields_elements', False)

def get_input_batch_type(self) -> typing.Optional[TypeConstraint]:
def get_input_batch_type(
self, input_element_type
) -> typing.Optional[typing.Union[TypeConstraint, type]]:
"""Determine the batch type expected as input to process_batch.
The default implementation of ``get_input_batch_type`` simply observes the
input typehint for the first parameter of ``process_batch``. A Batched DoFn
may override this method if a dynamic approach is required.
Args:
input_element_type: The **element type** of the input PCollection this
DoFn is being applied to.
Returns:
``None`` if this DoFn cannot accept batches, else a Beam typehint or
a native Python typehint.
"""
if not self.process_batch_defined:
return None
input_type = list(
Expand All @@ -746,10 +762,18 @@ def get_input_batch_type(self) -> typing.Optional[TypeConstraint]:
# TODO(BEAM-14340): Consider supporting an alternative (dynamic?) approach
# for declaring input type
raise TypeError(
f"{self.__class__.__name__}.process_batch() does not have a type "
"annotation on its first parameter. This is required for "
"process_batch implementations.")
return typehints.native_type_compatibility.convert_to_beam_type(input_type)
f"Either {self.__class__.__name__}.process_batch() must have a type "
f"annotation on its first parameter, or {self.__class__.__name__} "
"must override get_input_batch_type.")
return input_type

def _get_input_batch_type_normalized(self, input_element_type):
  """Return ``get_input_batch_type``'s result as a Beam typehint.

  ``get_input_batch_type`` may return a native Python typehint (e.g. an
  override written by a user); this wrapper converts it with
  ``convert_to_beam_type`` so internal consumers always see a normalized
  Beam type constraint.
  """
  declared = self.get_input_batch_type(input_element_type)
  return typehints.native_type_compatibility.convert_to_beam_type(declared)

def _get_output_batch_type_normalized(self, input_element_type):
  """Return ``get_output_batch_type``'s result as a Beam typehint.

  Mirrors ``_get_input_batch_type_normalized``: user overrides may return
  native Python typehints, so normalize with ``convert_to_beam_type``.
  """
  declared = self.get_output_batch_type(input_element_type)
  return typehints.native_type_compatibility.convert_to_beam_type(declared)

@staticmethod
def _get_element_type_from_return_annotation(method, input_type):
Expand All @@ -770,16 +794,32 @@ def _get_element_type_from_return_annotation(method, input_type):
f"{method!r}, did you mean Iterator[{return_type}]?")

def get_output_batch_type(
self, input_element_type) -> typing.Optional[TypeConstraint]:
self, input_element_type
) -> typing.Optional[typing.Union[TypeConstraint, type]]:
"""Determine the batch type produced by this DoFn's ``process_batch``
implementation and/or its ``process`` implementation with
``@yields_batch``.
The default implementation of this method observes the return type
annotations on ``process_batch`` and/or ``process``. A Batched DoFn may
override this method if a dynamic approach is required.
Args:
input_element_type: The **element type** of the input PCollection this
DoFn is being applied to.
Returns:
``None`` if this DoFn will never yield batches, else a Beam typehint or
a native Python typehint.
"""
output_batch_type = None
if self.process_defined and self.process_yields_batches:
# TODO: Use the element_type passed to infer_output_type instead of
# typehints.Any
output_batch_type = self._get_element_type_from_return_annotation(
self.process, input_element_type)
if self.process_batch_defined and not self.process_batch_yields_elements:
process_batch_type = self._get_element_type_from_return_annotation(
self.process_batch, self.get_input_batch_type())
self.process_batch,
self._get_input_batch_type_normalized(input_element_type))

# TODO: Consider requiring an inheritance relationship rather than
# equality
Expand Down Expand Up @@ -1415,7 +1455,8 @@ def infer_output_type(self, input_type):
def infer_batch_converters(self, input_element_type):
# TODO: Test this code (in batch_dofn_test)
if self.fn.process_batch_defined:
input_batch_type = self.fn.get_input_batch_type()
input_batch_type = self.fn._get_input_batch_type_normalized(
input_element_type)

if input_batch_type is None:
raise TypeError(
Expand All @@ -1430,7 +1471,8 @@ def infer_batch_converters(self, input_element_type):
self.fn.input_batch_converter = None

if self.fn.can_yield_batches:
output_batch_type = self.fn.get_output_batch_type(input_element_type)
output_batch_type = self.fn._get_output_batch_type_normalized(
input_element_type)
if output_batch_type is None:
# TODO: Mention process method in this error
raise TypeError(
Expand Down

0 comments on commit 5f04b97

Please sign in to comment.