Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-9324] Fix incompatibility of direct runner with cython #17728

Merged
merged 2 commits into from
May 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions sdks/python/apache_beam/runners/worker/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@
except ImportError:

class FakeCython(object):
@staticmethod
def cast(type, value):
return value
compiled = False

globals()['cython'] = FakeCython()

Expand All @@ -93,6 +91,22 @@ def cast(type, value):
SdfSplitResultsResidual = Tuple['DoOperation', 'SplitResultResidual']


# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
def _cast_to_operation(value):
  """Return ``value`` cast to ``Operation`` when running as compiled Cython.

  ``cython.cast`` is only invoked when ``cython.compiled`` is true; otherwise
  ``value`` is handed back unchanged.
  """
  if cython.compiled:
    return cython.cast(Operation, value)
  # Not compiled: nothing to cast, pass the object straight through.
  return value


# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
def _cast_to_receiver(value):
  """Return ``value`` cast to ``Receiver`` when running as compiled Cython.

  ``cython.cast`` is only invoked when ``cython.compiled`` is true; otherwise
  ``value`` is handed back unchanged.
  """
  if cython.compiled:
    return cython.cast(Receiver, value)
  # Not compiled: nothing to cast, pass the object straight through.
  return value


class ConsumerSet(Receiver):
"""A ConsumerSet represents a graph edge between two Operation nodes.

Expand Down Expand Up @@ -307,7 +321,7 @@ def receive(self, windowed_value):
self.update_counters_start(windowed_value)

for consumer in self.element_consumers:
cython.cast(Operation, consumer).process(windowed_value)
_cast_to_operation(consumer).process(windowed_value)

# TODO: Do this branching when constructing ConsumerSet
if self.has_batch_consumers:
Expand All @@ -324,10 +338,10 @@ def receive_batch(self, windowed_batch):
for wv in windowed_batch.as_windowed_values(
self.producer_batch_converter.explode_batch):
for consumer in self.element_consumers:
cython.cast(Operation, consumer).process(wv)
_cast_to_operation(consumer).process(wv)

for consumer in self.passthrough_batch_consumers:
cython.cast(Operation, consumer).process_batch(windowed_batch)
_cast_to_operation(consumer).process_batch(windowed_batch)

for (consumer_batch_converter,
consumers) in self.other_batch_consumers.items():
Expand All @@ -342,7 +356,7 @@ def receive_batch(self, windowed_batch):
"This is very inefficient, consider re-structuring your pipeline "
"or adding a DoFn to directly convert between these types.",
InefficientExecutionWarning)
cython.cast(Operation, consumer).process_batch(
_cast_to_operation(consumer).process_batch(
windowed_batch.with_values(
consumer_batch_converter.produce_batch(
self.producer_batch_converter.explode_batch(
Expand All @@ -358,13 +372,13 @@ def flush(self):
for windowed_batch in WindowedBatch.from_windowed_values(
self._batched_elements, produce_fn=batch_converter.produce_batch):
for consumer in consumers:
cython.cast(Operation, consumer).process_batch(windowed_batch)
_cast_to_operation(consumer).process_batch(windowed_batch)

for consumer in self.passthrough_batch_consumers:
for windowed_batch in WindowedBatch.from_windowed_values(
self._batched_elements,
produce_fn=self.producer_batch_converter.produce_batch):
cython.cast(Operation, consumer).process_batch(windowed_batch)
_cast_to_operation(consumer).process_batch(windowed_batch)

self._batched_elements = []

Expand Down Expand Up @@ -495,7 +509,7 @@ def finish(self):

"""Finish operation."""
for receiver in self.receivers:
cython.cast(Receiver, receiver).flush()
_cast_to_receiver(receiver).flush()

def teardown(self):
# type: () -> None
Expand All @@ -511,7 +525,7 @@ def reset(self):

def output(self, windowed_value, output_index=0):
# type: (WindowedValue, int) -> None
cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value)
_cast_to_receiver(self.receivers[output_index]).receive(windowed_value)

def add_receiver(self, operation, output_index=0):
# type: (Operation, int) -> None
Expand Down