Skip to content

Commit

Permalink
Use sentinel to elide the dag object on reserialization
Browse files Browse the repository at this point in the history
We don't serialize the dag on the task.dag attr when making RPC calls.  By marking it with a sentinel value, we can add understand when we're dealing with a deserialized object, and then re-set the dag attr while skipping some of the extra code applied in the setter.
  • Loading branch information
dstandish committed May 24, 2024
1 parent 9284029 commit 4a1152c
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 50 deletions.
16 changes: 13 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.types import ELIDED_DAG, NOTSET
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -1200,11 +1200,21 @@ def dag(self) -> DAG: # type: ignore[override]
@dag.setter
def dag(self, dag: DAG | None):
"""Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok."""
from airflow.models.dag import DAG

if dag is None:
self._dag = None
return

# if already set to elided, then just set and exit
if self._dag is ELIDED_DAG:
self._dag = dag
return
# if setting to elided, then just set and exit
if dag is ELIDED_DAG:
self._dag = ELIDED_DAG # type: ignore[assignment]
return

from airflow.models.dag import DAG

if not isinstance(dag, DAG):
raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
elif self.has_dag() and self.dag is not dag:
Expand Down
51 changes: 20 additions & 31 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.timeout import timeout
from airflow.utils.types import ELIDED_DAG
from airflow.utils.xcom import XCOM_RETURN_KEY

TR = TaskReschedule
Expand Down Expand Up @@ -904,13 +905,15 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
def _get_template_context(
*,
task_instance: TaskInstance | TaskInstancePydantic,
dag: DAG,
session: Session | None = None,
ignore_param_exceptions: bool = True,
) -> Context:
"""
Return TI Context.
:param task_instance: the task instance
:param task_instance: the task instance for the task
:param dag for the task
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
Expand All @@ -930,27 +933,10 @@ def _get_template_context(
assert task_instance.task
assert task
assert task.dag
try:
dag: DAG = task.dag
except AirflowException:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(task_instance, TaskInstancePydantic):
ti = session.scalar(
select(TaskInstance).where(
TaskInstance.task_id == task_instance.task_id,
TaskInstance.dag_id == task_instance.dag_id,
TaskInstance.run_id == task_instance.run_id,
TaskInstance.map_index == task_instance.map_index,
)
)
dag = ti.dag_model.serialized_dag.dag
if hasattr(task_instance.task, "_dag"): # BaseOperator
task_instance.task._dag = dag
else: # MappedOperator
task_instance.task.dag = dag
else:
raise
if task.dag is ELIDED_DAG:
task.dag = dag # required after deserialization

dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)

Expand Down Expand Up @@ -1280,16 +1266,8 @@ def _record_task_map_for_downstreams(
:meta private:
"""
# if not task._dag:
# task._dag = dag # required when on RPC server side

# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag

if task.dag is ELIDED_DAG:
task.dag = dag # required after deserialization
if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
Expand Down Expand Up @@ -3295,8 +3273,12 @@ def get_template_context(
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
"""
if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
Expand Down Expand Up @@ -3355,8 +3337,15 @@ def render_templates(
context = self.get_template_context()
original_task = self.task

ti = context["ti"]

if TYPE_CHECKING:
assert original_task
assert self.task
assert ti.task

if ti.task.dag is ELIDED_DAG:
ti.task.dag = self.task.dag

# If self.task is mapped, this call replaces self.task to point to the
# unmapped BaseOperator created by this function! This is because the
Expand Down
4 changes: 4 additions & 0 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,12 @@ def get_template_context(
"""
from airflow.models.taskinstance import _get_template_context

if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.timezone import from_timestamp, parse_timezone
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.types import ELIDED_DAG, NOTSET, ArgNotSet

if TYPE_CHECKING:
from inspect import Parameter
Expand Down Expand Up @@ -1292,7 +1292,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])

op.dag = ELIDED_DAG # type: ignore[assignment]
cls.populate_operator(op, encoded_op)
return op

Expand Down
19 changes: 19 additions & 0 deletions airflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
"""Sentinel value for argument default. See ``ArgNotSet``."""


class ElidedDag:
"""
Sentinel type to signal when dag elided on serialization.
:meta private:
"""

def __getattr__(self, item):
raise RuntimeError("Dag was elided on serialization and must be set again.")


ELIDED_DAG = ElidedDag()
"""
Sentinel value for dag elided on serialization. See ``ElidedDag``.
:meta private:
"""


class DagRunType(str, enum.Enum):
"""Class with DagRun types."""

Expand Down
21 changes: 14 additions & 7 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2472,13 +2472,20 @@ def test_operator_expand_deserialized_unmap():
@pytest.mark.db_test
def test_sensor_expand_deserialized_unmap():
"""Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor"""
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="a", mode="reschedule").expand(bash_command=[1, 2])

serialize = SerializedBaseOperator.serialize

deserialize = SerializedBaseOperator.deserialize
assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
dag = DAG(dag_id="hello", start_date=None)
with dag:
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="b", mode="reschedule").expand(bash_command=[1, 2])
ser_mapped = SerializedBaseOperator.serialize(mapped)
deser_mapped = SerializedBaseOperator.deserialize(ser_mapped)
deser_mapped.dag = dag
deser_unmapped = deser_mapped.unmap(None)
ser_normal = SerializedBaseOperator.serialize(normal)
deser_normal = SerializedBaseOperator.deserialize(ser_normal)
deser_normal.dag = dag
comps = set(BashSensor._comps)
comps.remove("task_id")
assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps)


def test_task_resources_serde():
Expand Down
27 changes: 20 additions & 7 deletions tests/serialization/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dag import DAG, DagModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetEvent,
Expand All @@ -43,7 +43,7 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from airflow.utils.types import ELIDED_DAG, DagRunType
from tests.models import DEFAULT_DATE

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
"task_id": "target",
}

with dag_maker():
with dag_maker() as dag:

@task
def source():
Expand Down Expand Up @@ -114,7 +114,7 @@ def target(val=None):
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)

assert desered.task.dag is ELIDED_DAG
assert "operator_class" not in sered["__var"]["task"]

assert desered.task.__class__ == MappedOperator
Expand All @@ -127,9 +127,22 @@ def target(val=None):

assert isinstance(desered.task.operator_class, dict)

resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected
# let's check that we can safely add back dag...
assert isinstance(dag, DAG)
# dag already has this task
assert dag.has_task(desered.task.task_id) is True
# but the task has no dag
assert desered.task.dag is ELIDED_DAG
# and there are no upstream / downstreams on the task cus those are wiped out on serialization
# and this is wrong / not great but that's how it is
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()
# add the dag back
desered.task.dag = dag
# great, no error
# but still, there are no upstream downstreams
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
Expand Down

0 comments on commit 4a1152c

Please sign in to comment.