diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index fa48b474ca..60061cbb0e 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -58,6 +58,8 @@ def __init__( self._max_concurrency = concurrency self._min_success_ratio = min_success_ratio self._array_task_interface = python_function_task.python_interface + if "metadata" not in kwargs and python_function_task.metadata: + kwargs["metadata"] = python_function_task.metadata super().__init__( name=name, interface=collection_interface, diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 80ab499416..95df669829 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -18,6 +18,12 @@ def t1(a: int) -> str: return str(b) +@task(cache=True, cache_version="1") +def t2(a: int) -> str: + b = a + 2 + return str(b) + + # This test is for documentation. def test_map_docs(): # test_map_task_start @@ -162,3 +168,11 @@ def many_outputs(a: int) -> (int, str): with pytest.raises(ValueError): _ = map_task(many_inputs) + + +def test_map_task_metadata(): + map_meta = TaskMetadata(retries=1) + mapped_1 = map_task(t2, metadata=map_meta) + assert mapped_1.metadata is map_meta + mapped_2 = map_task(t2) + assert mapped_2.metadata is t2.metadata