Sensor mention in taskflow concepts #28708

Merged (2 commits, Jan 4, 2023)
2 changes: 1 addition & 1 deletion airflow/decorators/sensor.py
@@ -56,7 +56,7 @@ def __init__(
         kwargs["task_id"] = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
         super().__init__(**kwargs)

-    def poke(self, context: Context) -> PokeReturnValue:
+    def poke(self, context: Context) -> PokeReturnValue | bool:
         return self.python_callable(*self.op_args, **self.op_kwargs)


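The widened annotation above means the decorated callable may return a plain boolean instead of constructing a ``PokeReturnValue``. A minimal sketch of that style (the DAG name, task name, and flag-file path are invented for illustration, not taken from the PR):

    import pendulum
    from pathlib import Path

    from airflow.decorators import dag, task


    @dag(start_date=pendulum.datetime(2023, 1, 1), schedule=None, catchup=False)
    def bool_sensor_demo():
        @task.sensor(poke_interval=30, timeout=300)
        def wait_for_flag() -> bool:
            # True marks the sensor's operation complete; False means "poke again later".
            return Path("/tmp/ready.flag").exists()

        wait_for_flag()


    bool_sensor_demo()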
2 changes: 1 addition & 1 deletion airflow/sensors/python.py
@@ -65,7 +65,7 @@ def __init__(
         self.op_kwargs = op_kwargs or {}
         self.templates_dict = templates_dict

-    def poke(self, context: Context) -> PokeReturnValue:
+    def poke(self, context: Context) -> PokeReturnValue | bool:
         context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
         self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)

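The same widening applies to ``PythonSensor`` itself: its ``python_callable`` may return a bare boolean. A minimal usage sketch, with a hypothetical flag file standing in for a real readiness check:

    import pendulum
    from pathlib import Path

    from airflow import DAG
    from airflow.sensors.python import PythonSensor


    def _flag_present() -> bool:
        # Returning True completes the sensor; False retries after poke_interval.
        return Path("/tmp/ready.flag").exists()


    with DAG(
        dag_id="python_sensor_bool_demo",
        start_date=pendulum.datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        wait = PythonSensor(
            task_id="wait_for_flag",
            python_callable=_flag_present,
            poke_interval=10,
            timeout=120,
        )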
9 changes: 9 additions & 0 deletions docs/apache-airflow/core-concepts/taskflow.rst
@@ -182,6 +182,15 @@ for deserialization ensure that ``deserialize(data: dict, version: int)`` is spe

 Note: Typing of ``__version__`` is required and needs to be ``ClassVar[int]``

+
+Sensors and the TaskFlow API
+--------------------------------------
+
+.. versionadded:: 2.5.0
+
+For an example of writing a Sensor using the TaskFlow API, see
+:ref:`Using the TaskFlow API with Sensor operators <taskflow/task_sensor_example>`.
+
History
-------

6 changes: 5 additions & 1 deletion docs/apache-airflow/tutorial/taskflow.rst
@@ -365,7 +365,11 @@ You can apply the ``@task.sensor`` decorator to convert a regular Python function
 BaseSensorOperator class. The Python function implements the poke logic and returns an instance of
 the ``PokeReturnValue`` class as the ``poke()`` method in the BaseSensorOperator does. The ``PokeReturnValue`` is
 a new feature in Airflow 2.3 that allows a sensor operator to push an XCom value as described in
-section "Having sensors return XOM values" of :doc:`apache-airflow-providers:howto/create-update-providers`.
+section "Having sensors return XCOM values" of :doc:`apache-airflow-providers:howto/create-update-providers`.
+
+Alternatively, in cases where the sensor doesn't need to push XCOM values, both ``poke()`` and the wrapped
+function can return a boolean value, where ``True`` designates the sensor's operation as complete and
+``False`` designates it as incomplete.

 .. _taskflow/task_sensor_example:
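For contrast with the boolean form added above, a sensor that does need to push an XCom can keep returning a ``PokeReturnValue`` and feed the value to a downstream task. A hedged sketch (the payload, DAG, and task names are illustrative):

    import pendulum

    from airflow.decorators import dag, task
    from airflow.sensors.base import PokeReturnValue


    @dag(start_date=pendulum.datetime(2023, 1, 1), schedule=None, catchup=False)
    def sensor_xcom_demo():
        @task.sensor(poke_interval=30, timeout=600)
        def wait_for_payload() -> PokeReturnValue:
            payload = {"status": "ok"}  # stand-in for a real readiness check
            # When is_done is True, xcom_value is pushed as the sensor's return-value XCom.
            return PokeReturnValue(is_done=payload is not None, xcom_value=payload)

        @task
        def consume(payload: dict):
            # Receives the sensor's pushed XCom via TaskFlow argument passing.
            print(payload)

        consume(wait_for_payload())


    sensor_xcom_demo()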
74 changes: 74 additions & 0 deletions tests/decorators/test_sensor.py
@@ -63,6 +63,30 @@ def dummy_f():
         )
         assert actual_xcom_value == sensor_xcom_value

+    def test_basic_sensor_success_returns_bool(self, dag_maker):
+        @task.sensor
+        def sensor_f():
+            return True
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
     def test_basic_sensor_failure(self, dag_maker):
         @task.sensor(timeout=0)
         def sensor_f():
@@ -89,6 +113,32 @@ def dummy_f():
             if ti.task_id == "dummy_f":
                 assert ti.state == State.NONE

+    def test_basic_sensor_failure_returns_bool(self, dag_maker):
+        @task.sensor(timeout=0)
+        def sensor_f():
+            return False
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        with pytest.raises(AirflowSensorTimeout):
+            sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.FAILED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
     def test_basic_sensor_soft_fail(self, dag_maker):
         @task.sensor(timeout=0, soft_fail=True)
         def sensor_f():
@@ -113,6 +163,30 @@ def dummy_f():
             if ti.task_id == "dummy_f":
                 assert ti.state == State.NONE

+    def test_basic_sensor_soft_fail_returns_bool(self, dag_maker):
+        @task.sensor(timeout=0, soft_fail=True)
+        def sensor_f():
+            return False
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SKIPPED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
     def test_basic_sensor_get_upstream_output(self, dag_maker):
         ret_val = 100
         sensor_xcom_value = "xcom_value"