Skip to content

Commit

Permalink
Disable nested task mapping for now
Browse files Browse the repository at this point in the history
I think I actually get at least most of the nested logic right-ish, but
this restriction is probably a good thing since nested task mapping
would require much work in many more places e.g. UI to work well from a
user's perspective.
  • Loading branch information
uranusjr committed Nov 15, 2022
1 parent 45b4d35 commit 6a8f814
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
3 changes: 3 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ def __repr__(self):
def __attrs_post_init__(self):
from airflow.models.xcom_arg import XComArg

if next(self.iter_mapped_task_groups(), None) is not None:
raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

if self.task_group:
self.task_group.add(self)
if self.dag:
Expand Down
18 changes: 11 additions & 7 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,11 +1700,15 @@ def tg(x):
# Normal operator in mapped task group, expands to 2 tis.
MockOperator(task_id="t1")
# Mapped operator expands *again* against mapped task group arguments to 4 tis.
MockOperator.partial(task_id="t2").expand(arg1=literal)
with pytest.raises(NotImplementedError) as ctx:
MockOperator.partial(task_id="t2").expand(arg1=literal)
assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"
# Normal operator referencing mapped task group arguments does not further expand, only 2 tis.
MockOperator(task_id="t3", arg1=x)
# It can expand *again* (since each item in x is a list) but this is not done at parse time.
MockOperator.partial(task_id="t4").expand(arg1=x)
with pytest.raises(NotImplementedError) as ctx:
MockOperator.partial(task_id="t4").expand(arg1=x)
assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"

tg.expand(x=literal)

Expand All @@ -1717,13 +1721,13 @@ def tg(x):
assert query.all() == [
("tg.t1", 0, None),
("tg.t1", 1, None),
("tg.t2", 0, None),
("tg.t2", 1, None),
("tg.t2", 2, None),
("tg.t2", 3, None),
# ("tg.t2", 0, None),
# ("tg.t2", 1, None),
# ("tg.t2", 2, None),
# ("tg.t2", 3, None),
("tg.t3", 0, None),
("tg.t3", 1, None),
("tg.t4", -1, None),
# ("tg.t4", -1, None),
]


Expand Down
6 changes: 4 additions & 2 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,9 @@ def test_mapped_task_group_serde():
@task_group
def tg(a: str) -> None:
BaseOperator(task_id="op1")
BashOperator.partial(task_id="op2").expand(bash_command=["ls", a])
with pytest.raises(NotImplementedError) as ctx:
BashOperator.partial(task_id="op2").expand(bash_command=["ls", a])
assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"

tg.expand(a=[".", ".."])

Expand All @@ -2307,7 +2309,7 @@ def tg(a: str) -> None:
"_group_id": "tg",
"children": {
"tg.op1": ("operator", "tg.op1"),
"tg.op2": ("operator", "tg.op2"),
# "tg.op2": ("operator", "tg.op2"),
},
"downstream_group_ids": [],
"downstream_task_ids": [],
Expand Down

0 comments on commit 6a8f814

Please sign in to comment.