Skip to content

Commit

Permalink
Handle db isolation for mapped operators and task groups (#39259)
Browse files Browse the repository at this point in the history
* Handle db isolation for mapped operators and task groups

* Update airflow/models/taskinstance.py
  • Loading branch information
dstandish committed Jun 14, 2024
1 parent a1f9b7d commit e69ab3a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 47 deletions.
4 changes: 4 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _record_task_map_for_downstreams
from airflow.models.xcom_arg import _get_task_map_length
from airflow.sensors.base import _orig_start_date
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session
Expand Down Expand Up @@ -66,12 +68,14 @@ def _initialize_map() -> dict[str, Callable]:
_defer_task,
_get_template_context,
_get_ti_db_access,
_get_task_map_length,
_update_rtif,
_orig_start_date,
_handle_failure,
_handle_reschedule,
_add_log,
_xcom_pull,
_record_task_map_for_downstreams,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand Down
33 changes: 28 additions & 5 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,14 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
for key, value in xcom_value.items():
task_instance.xcom_push(key=key, value=value, session=session_or_null)
task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
if TYPE_CHECKING:
assert task_orig.dag
_record_task_map_for_downstreams(
task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null
task_instance=task_instance,
task=task_orig,
dag=task_orig.dag,
value=xcom_value,
session=session_or_null,
)
return result

Expand Down Expand Up @@ -1249,25 +1255,40 @@ def _refresh_from_task(
task_instance_mutation_hook(task_instance)


@internal_api_call
@provide_session
def _record_task_map_for_downstreams(
*, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, value: Any, session: Session
*,
task_instance: TaskInstance | TaskInstancePydantic,
task: Operator,
dag: DAG,
value: Any,
session: Session,
) -> None:
"""
Record the task map for downstream tasks.
:param task_instance: the task instance
:param task: The task object
:param dag: the dag associated with the task
:param value: The value
:param session: SQLAlchemy ORM Session
:meta private:
"""
# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag

if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
if isinstance(task, MappedOperator):
return
if value is None:
Expand Down Expand Up @@ -3355,6 +3376,8 @@ def render_templates(
# MappedOperator is useless for template rendering, and we need to be
# able to access the unmapped task instead.
original_task.render_template_fields(context, jinja_env)
if isinstance(self.task, MappedOperator):
self.task = context["ti"].task

return original_task

Expand Down
100 changes: 58 additions & 42 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

from sqlalchemy import func, or_, select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, XComNotFound
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.utils.db import exists_query
from airflow.utils.mixins import ResolveMixin
Expand Down Expand Up @@ -222,6 +223,53 @@ def __exit__(self, exc_type, exc_val, exc_tb):
SetupTeardownContext.set_work_task_roots_and_leaves()


@internal_api_call
@provide_session
def _get_task_map_length(
    *,
    dag_id: str,
    task_id: str,
    run_id: str,
    is_mapped: bool,
    session: Session = NEW_SESSION,
) -> int | None:
    """
    Look up the map length recorded for a task in a given DAG run.

    For a task that is not mapped, the value comes from the ``TaskMap.length``
    column of the row matching the task with a negative ``map_index``.

    For a mapped task, the value is the count of ``XCom`` rows stored under
    ``XCOM_RETURN_KEY`` with a non-negative ``map_index``. If any task instance
    of that task still has a NULL or unfinished state, ``None`` is returned
    instead of a count.

    :param dag_id: ID of the DAG the task belongs to
    :param task_id: ID of the task whose map length is wanted
    :param run_id: ID of the DAG run to look in
    :param is_mapped: whether the task is a mapped operator
    :param session: SQLAlchemy ORM Session
    :return: the result of the scalar query, or ``None`` when expanded task
        instances are not all finished

    :meta private:
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.taskmap import TaskMap
    from airflow.models.xcom import XCom

    if not is_mapped:
        # Plain task: the length was recorded directly in the TaskMap table.
        recorded_length = select(TaskMap.length).where(
            TaskMap.dag_id == dag_id,
            TaskMap.run_id == run_id,
            TaskMap.task_id == task_id,
            TaskMap.map_index < 0,
        )
        return session.scalar(recorded_length)

    # Special NULL treatment is needed because 'state' can be NULL.
    # The "IN" part would produce "NULL NOT IN ..." and eventually
    # "NULL = NULL", which is a big no-no in SQL.
    state_is_unfinished = or_(
        TaskInstance.state.is_(None),
        TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
    )
    if exists_query(
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id == task_id,
        state_is_unfinished,
        session=session,
    ):
        return None  # Not all of the expanded tis are done yet.

    # Mapped task: count the return-value XComs pushed by expanded tis.
    pushed_return_count = select(func.count(XCom.map_index)).where(
        XCom.dag_id == dag_id,
        XCom.run_id == run_id,
        XCom.task_id == task_id,
        XCom.map_index >= 0,
        XCom.key == XCOM_RETURN_KEY,
    )
    return session.scalar(pushed_return_count)


class PlainXComArg(XComArg):
"""Reference to one single XCom without any additional semantics.
Expand Down Expand Up @@ -364,51 +412,19 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
return super().zip(*others, fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom

task = self.operator
if isinstance(task, MappedOperator):
unfinished_ti_exists = exists_query(
TaskInstance.dag_id == task.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id == task.task_id,
# Special NULL treatment is needed because 'state' can be NULL.
# The "IN" part would produce "NULL NOT IN ..." and eventually
# "NULL = NULL", which is a big no-no in SQL.
or_(
TaskInstance.state.is_(None),
TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
),
session=session,
)
if unfinished_ti_exists:
return None # Not all of the expanded tis are done yet.
query = select(func.count(XCom.map_index)).where(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
XCom.key == XCOM_RETURN_KEY,
)
else:
query = select(TaskMap.length).where(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
return session.scalar(query)
return _get_task_map_length(
dag_id=self.operator.dag_id,
task_id=self.operator.task_id,
is_mapped=isinstance(self.operator, MappedOperator),
run_id=run_id,
session=session,
)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
from airflow.models.taskinstance import TaskInstance

ti = context["ti"]
if not isinstance(ti, TaskInstance):
raise NotImplementedError("Wait for AIP-44 implementation to complete")

if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
self.operator,
Expand Down

0 comments on commit e69ab3a

Please sign in to comment.