diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index f46a4a183699a..09706fe7187de 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -78,9 +78,7 @@ except ImportError: class FakeCython(object): - @staticmethod - def cast(type, value): - return value + compiled = False globals()['cython'] = FakeCython() @@ -93,6 +91,22 @@ def cast(type, value): SdfSplitResultsResidual = Tuple['DoOperation', 'SplitResultResidual'] +# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3 +def _cast_to_operation(value): + if cython.compiled: + return cython.cast(Operation, value) + else: + return value + + +# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3 +def _cast_to_receiver(value): + if cython.compiled: + return cython.cast(Receiver, value) + else: + return value + + class ConsumerSet(Receiver): """A ConsumerSet represents a graph edge between two Operation nodes. @@ -307,7 +321,7 @@ def receive(self, windowed_value): self.update_counters_start(windowed_value) for consumer in self.element_consumers: - cython.cast(Operation, consumer).process(windowed_value) + _cast_to_operation(consumer).process(windowed_value) # TODO: Do this branching when contstructing ConsumerSet if self.has_batch_consumers: @@ -324,10 +338,10 @@ def receive_batch(self, windowed_batch): for wv in windowed_batch.as_windowed_values( self.producer_batch_converter.explode_batch): for consumer in self.element_consumers: - cython.cast(Operation, consumer).process(wv) + _cast_to_operation(consumer).process(wv) for consumer in self.passthrough_batch_consumers: - cython.cast(Operation, consumer).process_batch(windowed_batch) + _cast_to_operation(consumer).process_batch(windowed_batch) for (consumer_batch_converter, consumers) in self.other_batch_consumers.items(): @@ -342,7 +356,7 @@ def receive_batch(self, windowed_batch): "This is very inefficient, consider re-structuring your pipeline " "or adding a DoFn to directly convert between these types.", InefficientExecutionWarning) - cython.cast(Operation, consumer).process_batch( + _cast_to_operation(consumer).process_batch( windowed_batch.with_values( consumer_batch_converter.produce_batch( self.producer_batch_converter.explode_batch( @@ -358,13 +372,13 @@ def flush(self): for windowed_batch in WindowedBatch.from_windowed_values( self._batched_elements, produce_fn=batch_converter.produce_batch): for consumer in consumers: - cython.cast(Operation, consumer).process_batch(windowed_batch) + _cast_to_operation(consumer).process_batch(windowed_batch) for consumer in self.passthrough_batch_consumers: for windowed_batch in WindowedBatch.from_windowed_values( self._batched_elements, produce_fn=self.producer_batch_converter.produce_batch): - cython.cast(Operation, consumer).process_batch(windowed_batch) + _cast_to_operation(consumer).process_batch(windowed_batch) self._batched_elements = [] @@ -495,7 +509,7 @@ def finish(self): """Finish operation.""" for receiver in self.receivers: - cython.cast(Receiver, receiver).flush() + _cast_to_receiver(receiver).flush() def teardown(self): # type: () -> None @@ -511,7 +525,7 @@ def reset(self): def output(self, windowed_value, output_index=0): # type: (WindowedValue, int) -> None - cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value) + _cast_to_receiver(self.receivers[output_index]).receive(windowed_value) def add_receiver(self, operation, output_index=0): # type: (Operation, int) -> None