Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand tasks in mapped group at parse time #27158

Merged
merged 2 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@
ListOfDictsExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
is_mappable,
)
from airflow.models.mappedoperator import (
MappedOperator,
ValidationSource,
ensure_xcomarg_return_value,
get_mappable_types,
)
from airflow.models.mappedoperator import MappedOperator, ValidationSource, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec, Protocol
Expand Down Expand Up @@ -100,7 +96,7 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) ->
kwargs_left = kwargs.copy()
for arg_name in self._mappable_function_argument_names:
value = kwargs_left.pop(arg_name, NOTSET)
if func != "expand" or value is NOTSET or isinstance(value, get_mappable_types()):
if func != "expand" or value is NOTSET or is_mappable(value):
continue
tname = type(value).__name__
raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}")
Expand Down
21 changes: 3 additions & 18 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,19 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Mapping, Sequence, TypeVar, overload

import attr
from sqlalchemy.orm import Session

from airflow.decorators.base import ExpandableFactory
from airflow.models.expandinput import (
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
MappedArgument,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec
from airflow.utils.context import Context
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.task_group import MappedTaskGroup, TaskGroup

if TYPE_CHECKING:
Expand All @@ -58,17 +54,6 @@
task_group_sig = inspect.signature(TaskGroup.__init__)


@attr.define(kw_only=True)
class _MappedArgument(ResolveMixin):
    """Lazy stand-in for one mapped keyword argument.

    Holds the whole expand input plus the key to pick out of it; the actual
    value is only looked up at resolve time, per task instance.
    """

    _input: ExpandInput
    _key: str

    @provide_session
    def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
        # The expand input resolves to (data, lengths); only the data mapping
        # is needed here to extract this argument's value.
        payload, _ = self._input.resolve(context, session=session)
        return payload[self._key]


@attr.define()
class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]):
function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
Expand Down Expand Up @@ -146,7 +131,7 @@ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
return self._create_task_group(
functools.partial(MappedTaskGroup, expand_input=expand_input),
**self.partial_kwargs,
**{k: _MappedArgument(input=expand_input, key=k) for k in kwargs},
**{k: MappedArgument(input=expand_input, key=k) for k in kwargs},
)

def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
Expand Down Expand Up @@ -175,7 +160,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
return self._create_task_group(
functools.partial(MappedTaskGroup, expand_input=expand_input),
**self.partial_kwargs,
**{k: _MappedArgument(input=expand_input, key=k) for k in map_kwargs},
**{k: MappedArgument(input=expand_input, key=k) for k in map_kwargs},
)


Expand Down
54 changes: 53 additions & 1 deletion airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from __future__ import annotations

import datetime
import functools
import inspect
import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence

from airflow.compat.functools import cached_property
from airflow.compat.functools import cache, cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.taskmixin import DAGNode
Expand All @@ -30,6 +32,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

Expand Down Expand Up @@ -65,6 +68,10 @@
)


class NotMapped(Exception):
    """Raised when a task is neither mapped itself nor inside any mapped task group."""


class AbstractOperator(LoggingMixin, DAGNode):
"""Common implementation for operators, including unmapped and mapped.

Expand Down Expand Up @@ -276,6 +283,14 @@ def iter_mapped_dependants(self) -> Iterator[MappedOperator]:
if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
)

def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
    """Yield every mapped task group this task belongs to, innermost first.

    Walks the ``task_group`` chain upward from this task's immediate group
    and yields only the groups that are ``MappedTaskGroup`` instances.
    """
    group = self.task_group
    while group is not None:
        if isinstance(group, MappedTaskGroup):
            yield group
        group = group.task_group

def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""Get the "normal" operator from current abstract operator.

Expand Down Expand Up @@ -370,6 +385,43 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
return link.get_link(self.unmap(None), ti_key=ti.key)

@cache
def get_parse_time_mapped_ti_count(self) -> int:
    """Number of mapped task instances that can be created on DAG run creation.

    This only considers literal mapped arguments; when any non-literal value
    is used for mapping, the count cannot be known at parse time and
    ``NotFullyPopulated`` is raised instead (this never returns ``None``).

    The total is the product of the counts of all enclosing mapped task
    groups, since each level of mapping multiplies the number of instances.

    NOTE(review): ``@cache`` on an instance method keys on ``self`` and keeps
    every operator alive for the cache's lifetime (flake8-bugbear B019);
    presumably acceptable because operators live as long as the parsed DAG —
    confirm.

    :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
    :raise NotMapped: If the operator is neither mapped, nor has any parent
        mapped task groups.
    :return: Total number of mapped TIs this task should have.
    """
    mapped_task_groups = list(self.iter_mapped_task_groups())
    if not mapped_task_groups:
        raise NotMapped
    counts = (g.get_parse_time_mapped_ti_count() for g in mapped_task_groups)
    return functools.reduce(operator.mul, counts)

def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
    """Number of mapped TaskInstances that can be created at run time.

    This considers both literal and non-literal mapped arguments, and the
    result is therefore available when all depended tasks have finished. The
    return value should be identical to ``parse_time_mapped_ti_count`` if
    all mapped arguments are literal.

    The total is the product of the counts of all enclosing mapped task
    groups, since each level of mapping multiplies the number of instances.

    :raise NotFullyPopulated: If upstream tasks are not all complete yet.
    :raise NotMapped: If the operator is neither mapped, nor has any parent
        mapped task groups.
    :return: Total number of mapped TIs this task should have.
    """
    groups = list(self.iter_mapped_task_groups())
    if not groups:
        raise NotMapped
    total = 1
    for group in groups:
        total *= group.get_mapped_ti_count(run_id, session=session)
    return total

def render_template_fields(
self,
context: Context,
Expand Down
84 changes: 39 additions & 45 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,7 @@
import warnings
from collections import defaultdict
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypeVar,
cast,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

from sqlalchemy import (
Boolean,
Expand All @@ -59,7 +48,9 @@
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
Expand Down Expand Up @@ -907,7 +898,6 @@ def _check_for_removed_or_restored_tasks(
for ti in tis:
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
task = None
try:
task = dag.get_task(ti.task_id)

Expand All @@ -925,28 +915,15 @@ def _check_for_removed_or_restored_tasks(
ti.state = State.REMOVED
continue

if not task.is_mapped:
try:
num_mapped_tis = task.get_parse_time_mapped_ti_count()
except NotMapped:
continue
task = cast("MappedOperator", task)
num_mapped_tis = task.parse_time_mapped_ti_count
# Check if the number of mapped literals has changed and we need to mark this TI as removed
if num_mapped_tis is not None:
if ti.map_index >= num_mapped_tis:
self.log.debug(
"Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
ti,
num_mapped_tis,
)
ti.state = State.REMOVED
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
ti.state = State.REMOVED
else:
# What if it is _now_ dynamically mapped, but wasn't before?
task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)

if total_length is None:
except NotFullyPopulated:
# What if it is _now_ dynamically mapped, but wasn't before?
try:
total_length = task.get_mapped_ti_count(self.run_id, session=session)
except NotFullyPopulated:
# Not all upstreams finished, so we can't tell what should be here. Remove everything.
if ti.map_index >= 0:
self.log.debug(
Expand All @@ -962,6 +939,18 @@ def _check_for_removed_or_restored_tasks(
total_length,
)
ti.state = State.REMOVED
else:
# Check if the number of mapped literals has changed and we need to mark this TI as removed.
if ti.map_index >= num_mapped_tis:
self.log.debug(
"Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
ti,
num_mapped_tis,
)
ti.state = State.REMOVED
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
ti.state = State.REMOVED

return task_ids

Expand Down Expand Up @@ -1033,15 +1022,20 @@ def _create_tasks(
:param tasks: Tasks to create jobs for in the DAG run
:param task_creator: Function to create task instances
"""
map_indexes: Iterable[int]
for task in tasks:
if not task.is_mapped:
yield from task_creator(task, (-1,))
continue
count = cast(MappedOperator, task).get_mapped_ti_count(self.run_id, session=session)
if count:
yield from task_creator(task, range(count))
continue
yield from task_creator(task, (-1,))
try:
count = task.get_mapped_ti_count(self.run_id, session=session)
except (NotMapped, NotFullyPopulated):
map_indexes = (-1,)
else:
if count:
map_indexes = range(count)
else:
# Make sure to always create at least one ti; this will be
# marked as REMOVED later at runtime.
map_indexes = (-1,)
yield from task_creator(task, map_indexes)

def _create_task_instances(
self,
Expand Down Expand Up @@ -1090,9 +1084,9 @@ def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) ->
"""Check if task increased or reduced in length and handle appropriately"""
from airflow.settings import task_instance_mutation_hook

task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)
if total_length is None: # Upstreams not ready, don't need to revise this yet.
try:
total_length = task.get_mapped_ti_count(self.run_id, session=session)
except NotFullyPopulated: # Upstreams not ready, don't need to revise this yet.
return []

query = session.query(TI.map_index).filter(
Expand Down
Loading