From a025e86bcea5b5e6ba61c23e66df732bdac782d2 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 24 May 2022 10:26:21 +0100 Subject: [PATCH] Move MappedOperator tests to mirror code location At some point during the development of AIP-42 we moved the code for MappedOperator out of baseoperator.py to mappedoperator.py, but we didn't move the tests at the same time --- tests/models/test_baseoperator.py | 252 ------------------------- tests/models/test_mappedoperator.py | 278 ++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 252 deletions(-) create mode 100644 tests/models/test_mappedoperator.py diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 5ba271a5a136a..8c75c86ed43c1 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -22,7 +22,6 @@ from unittest import mock import jinja2 -import pendulum import pytest from airflow.decorators import task as task_decorator @@ -30,20 +29,13 @@ from airflow.lineage.entities import File from airflow.models import DAG from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream -from airflow.models.mappedoperator import MappedOperator -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskmap import TaskMap -from airflow.models.xcom import XCOM_RETURN_KEY -from airflow.models.xcom_arg import XComArg from airflow.utils.context import Context from airflow.utils.edgemodifier import Label -from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule from tests.models import DEFAULT_DATE from tests.test_utils.config import conf_vars -from tests.test_utils.mapping import expand_mapped_task from tests.test_utils.mock_operators import DeprecatedOperator, MockOperator @@ -752,250 +744,6 @@ def test_operator_retries(caplog, dag_maker, retries, expected): assert caplog.record_tuples == expected -def test_task_mapping_with_dag(): - with DAG("test-dag", start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) - finish = MockOperator(task_id="finish") - - task1 >> mapped >> finish - - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - assert mapped.task_group == dag.task_group - # At parse time there should only be three tasks! - assert len(dag.tasks) == 3 - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - -def test_task_mapping_without_dag_context(): - with DAG("test-dag", start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) - - task1 >> mapped - - assert isinstance(mapped, MappedOperator) - assert mapped in dag.tasks - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - # At parse time there should only be two tasks! - assert len(dag.tasks) == 2 - - -def test_task_mapping_default_args(): - default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'} - with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): - task1 = BaseOperator(task_id="op1") - literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) - - task1 >> mapped - - assert mapped.partial_kwargs['owner'] == 'test' - assert mapped.start_date == pendulum.instance(default_args['start_date']) - - -def test_map_unknown_arg_raises(): - with pytest.raises(TypeError, match=r"argument 'file'"): - BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}]) - - -def test_map_xcom_arg(): - """Test that dependencies are correct when mapping with an XComArg""" - with DAG("test-dag", start_date=DEFAULT_DATE): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) - finish = MockOperator(task_id="finish") - - mapped >> finish - - assert task1.downstream_list == [mapped] - - -def test_partial_on_instance() -> None: - """`.partial` on an instance should fail -- it's only designed to be called on classes""" - with pytest.raises(TypeError): - MockOperator( - task_id='a', - ).partial() - - -def test_partial_on_class() -> None: - # Test that we accept args for superclasses too - op = MockOperator.partial(task_id='a', arg1="a", trigger_rule=TriggerRule.ONE_FAILED) - assert op.kwargs["arg1"] == "a" - assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED - - -def test_partial_on_class_invalid_ctor_args() -> None: - """Test that when we pass invalid args to partial(). - - I.e. if an arg is not known on the class or any of its parent classes we error at parse time - """ - with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): - MockOperator.partial(task_id='a', foo='bar', bar=2) - - -@pytest.mark.parametrize( - ["num_existing_tis", "expected"], - ( - pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'), - pytest.param( - 3, - [(0, 'success'), (1, 'success'), (2, 'success')], - id='all-tis-exist', - ), - pytest.param( - 5, - [ - (0, 'success'), - (1, 'success'), - (2, 'success'), - (3, TaskInstanceState.REMOVED), - (4, TaskInstanceState.REMOVED), - ], - id="tis-to-be-removed", - ), - ), -) -def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): - literal = [1, 2, {'a': 'b'}] - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) - - dr = dag_maker.create_dagrun() - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=len(literal), - keys=None, - ) - ) - - if num_existing_tis: - # Remove the map_index=-1 TI when we're creating other TIs - session.query(TaskInstance).filter( - TaskInstance.dag_id == mapped.dag_id, - TaskInstance.task_id == mapped.task_id, - TaskInstance.run_id == dr.run_id, - ).delete() - - for index in range(num_existing_tis): - # Give the existing TIs a state to make sure we don't change them - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) - session.add(ti) - session.flush() - - mapped.expand_mapped_task(dr.run_id, session=session) - - indices = ( - session.query(TaskInstance.map_index, TaskInstance.state) - .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) - .order_by(TaskInstance.map_index) - .all() - ) - - assert indices == expected - - -def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) - - dr = dag_maker.create_dagrun() - - expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, session=session) - - indices = ( - session.query(TaskInstance.map_index, TaskInstance.state) - .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) - .order_by(TaskInstance.map_index) - .all() - ) - - assert indices == [(-1, TaskInstanceState.SKIPPED)] - - -def test_mapped_task_applies_default_args_classic(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - MockOperator(task_id="simple", arg1=None, arg2=0) - MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -def test_mapped_task_applies_default_args_taskflow(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - - @dag.task - def simple(arg): - pass - - @dag.task - def mapped(arg): - pass - - simple(arg=0) - mapped.expand(arg=[1, 2]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -def test_mapped_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(MockOperator): - def __init__(self, value, arg1, **kwargs): - assert isinstance(value, str), "value should have been resolved before unmapping" - assert isinstance(arg1, str), "value should have been resolved before unmapping" - super().__init__(arg1=arg1, **kwargs) - self.value = value - - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - xcom_arg = XComArg(task1) - mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand( - value=xcom_arg, arg1=xcom_arg - ) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() - - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(op, MyOperator) - - assert op.value == "{{ ds }}", "Should not be templated!" - assert op.arg1 == "{{ ds }}" - assert op.arg2 == "a" - - def test_default_retry_delay(dag_maker): with dag_maker(dag_id='test_default_retry_delay'): task1 = BaseOperator(task_id='test_no_explicit_retry_delay') diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py new file mode 100644 index 0000000000000..c720fd96d9b84 --- /dev/null +++ b/tests/models/test_mappedoperator.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import timedelta + +import pendulum +import pytest + +from airflow.models import DAG +from airflow.models.baseoperator import BaseOperator +from airflow.models.mappedoperator import MappedOperator +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap +from airflow.models.xcom import XCOM_RETURN_KEY +from airflow.models.xcom_arg import XComArg +from airflow.utils.state import TaskInstanceState +from airflow.utils.trigger_rule import TriggerRule +from tests.models import DEFAULT_DATE +from tests.test_utils.mapping import expand_mapped_task +from tests.test_utils.mock_operators import MockOperator + + +def test_task_mapping_with_dag(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) + finish = MockOperator(task_id="finish") + + task1 >> mapped >> finish + + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + assert mapped.task_group == dag.task_group + # At parse time there should only be three tasks! + assert len(dag.tasks) == 3 + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_without_dag_context(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) + + task1 >> mapped + + assert isinstance(mapped, MappedOperator) + assert mapped in dag.tasks + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + # At parse time there should only be two tasks! + assert len(dag.tasks) == 2 + + +def test_task_mapping_default_args(): + default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'} + with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator.partial(task_id='task_2').expand(arg2=literal) + + task1 >> mapped + + assert mapped.partial_kwargs['owner'] == 'test' + assert mapped.start_date == pendulum.instance(default_args['start_date']) + + +def test_map_unknown_arg_raises(): + with pytest.raises(TypeError, match=r"argument 'file'"): + BaseOperator.partial(task_id='a').expand(file=[1, 2, {'a': 'b'}]) + + +def test_map_xcom_arg(): + """Test that dependencies are correct when mapping with an XComArg""" + with DAG("test-dag", start_date=DEFAULT_DATE): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) + finish = MockOperator(task_id="finish") + + mapped >> finish + + assert task1.downstream_list == [mapped] + + +def test_partial_on_instance() -> None: + """`.partial` on an instance should fail -- it's only designed to be called on classes""" + with pytest.raises(TypeError): + MockOperator( + task_id='a', + ).partial() + + +def test_partial_on_class() -> None: + # Test that we accept args for superclasses too + op = MockOperator.partial(task_id='a', arg1="a", trigger_rule=TriggerRule.ONE_FAILED) + assert op.kwargs["arg1"] == "a" + assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED + + +def test_partial_on_class_invalid_ctor_args() -> None: + """Test that when we pass invalid args to partial(). + + I.e. if an arg is not known on the class or any of its parent classes we error at parse time + """ + with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): + MockOperator.partial(task_id='a', foo='bar', bar=2) + + +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'), + pytest.param( + 3, + [(0, 'success'), (1, 'success'), (2, 'success')], + id='all-tis-exist', + ), + pytest.param( + 5, + [ + (0, 'success'), + (1, 'success'), + (2, 'success'), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + id="tis-to-be-removed", + ), + ), +) +def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [1, 2, {'a': 'b'}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + mapped.expand_mapped_task(dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected + + +def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1)) + + dr = dag_maker.create_dagrun() + + expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == [(-1, TaskInstanceState.SKIPPED)] + + +def test_mapped_task_applies_default_args_classic(dag_maker): + with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + MockOperator(task_id="simple", arg1=None, arg2=0) + MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +def test_mapped_task_applies_default_args_taskflow(dag_maker): + with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + + @dag.task + def simple(arg): + pass + + @dag.task + def mapped(arg): + pass + + simple(arg=0) + mapped.expand(arg=[1, 2]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +def test_mapped_render_template_fields_validating_operator(dag_maker, session): + class MyOperator(MockOperator): + def __init__(self, value, arg1, **kwargs): + assert isinstance(value, str), "value should have been resolved before unmapping" + assert isinstance(arg1, str), "value should have been resolved before unmapping" + super().__init__(arg1=arg1, **kwargs) + self.value = value + + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + xcom_arg = XComArg(task1) + mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand( + value=xcom_arg, arg1=xcom_arg + ) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=['{{ ds }}'], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) + ) + session.flush() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + mapped_ti.map_index = 0 + op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(op, MyOperator) + + assert op.value == "{{ ds }}", "Should not be templated!" + assert op.arg1 == "{{ ds }}" + assert op.arg2 == "a"