Skip to content

Commit

Permalink
add validation for duplicate template and node names (#1054)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [ ] Fixes #<!--issue number goes here-->
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, hera lacks validation for duplicate template names, causing
templates with duplicate names to be missing when rendering to yaml.
Hera also lacks validation for duplicate node names, which results in
rendered yaml that is invalid when submitted to argo-workflows.

This PR adds validation for both situations, preventing the user from
rendering incorrect or invalid yaml when Workflows contain multiple
templates or nodes with the same name by raising a TemplateNameConflict
or NodeNameConflict, respectively.

Note that the order of operations has been adjusted in
`_HeraContext.add_sub_node`. This change was required to continue to
support workflows with recursive references.

---------

Signed-off-by: crflynn <flynn@simplebet.io>
Signed-off-by: Elliot Gunton <elliotgunton@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
  • Loading branch information
crflynn and elliotgunton authored May 31, 2024
1 parent 08038f0 commit fe4d5c9
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 21 deletions.
11 changes: 11 additions & 0 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, FieldInfo]:
return cls.__fields__ # type: ignore


__all__ = [
"BaseModel",
"Field",
"PrivateAttr",
"PydanticBaseModel", # Export for serialization.py to cover user-defined models
"ValidationError",
"root_validator",
"validator",
]


def get_field_annotations(cls: Type[PydanticBaseModel]) -> Dict[str, Any]:
return {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

Expand Down
10 changes: 9 additions & 1 deletion src/hera/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from hera.workflows.data import Data
from hera.workflows.env import ConfigMapEnv, Env, FieldEnv, ResourceEnv, SecretEnv
from hera.workflows.env_from import ConfigMapEnvFrom, SecretEnvFrom
from hera.workflows.exceptions import InvalidDispatchType, InvalidTemplateCall, InvalidType
from hera.workflows.exceptions import (
InvalidDispatchType,
InvalidTemplateCall,
InvalidType,
NodeNameConflict,
TemplateNameConflict,
)
from hera.workflows.http_template import HTTP
from hera.workflows.io import Input, Output, RunnerInput, RunnerOutput
from hera.workflows.metrics import Counter, Gauge, Histogram, Label, Metric, Metrics
Expand Down Expand Up @@ -134,6 +140,7 @@
"Metric",
"Metrics",
"NFSVolume",
"NodeNameConflict",
"NoneArchiveStrategy",
"OSSArtifact",
"Operator",
Expand Down Expand Up @@ -168,6 +175,7 @@
"TarArchiveStrategy",
"Task",
"TaskResult",
"TemplateNameConflict",
"UserContainer",
"Volume",
"VsphereVirtualDiskVolume",
Expand Down
35 changes: 19 additions & 16 deletions src/hera/workflows/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Optional, TypeVar, Union

from hera.shared import BaseMixin
from hera.workflows.exceptions import InvalidType
from hera.workflows.exceptions import InvalidType, TemplateNameConflict
from hera.workflows.protocol import Subbable, TTemplate

TNode = TypeVar("TNode", bound="SubNodeMixin")
Expand Down Expand Up @@ -85,27 +85,30 @@ def add_sub_node(self, node: Union[SubNodeMixin, TTemplate]) -> None:
if not pieces:
return

try:
# here, we are trying to add a node to the last piece of context in the hopes that it is a subbable
pieces[-1]._add_sub(node)
except InvalidType:
# if the above fails, it means the user invoked a decorated function e.g. `@script`. Hence,
# the object needs to be added as a template to the piece of context at [-1]. This will be the case for
# DAGs and Steps
pieces[-1]._add_sub(node.template) # type: ignore

# when the above does not raise an exception, it means the user invoked a decorated function e.g. `@script`
# inside a proper context. Here, we add the object to the overall workflow context, directly as a template,
# in case it is not found (based on the name). This helps users save on the number of templates that are
# added when using an object that is a `Script`
if hasattr(node, "template") and node.template is not None and not isinstance(node.template, str):
# When the user invokes a decorated function e.g. `@script inside a sub-context (dag/steps),
# we also add the step/task's template to the overall workflow context, if it is not already added.
from hera.workflows._mixins import TemplateInvocatorSubNodeMixin

if (
isinstance(node, TemplateInvocatorSubNodeMixin)
and node.template is not None
and not isinstance(node.template, str)
):
from hera.workflows.workflow import Workflow

assert isinstance(pieces[0], Workflow)

found = False
for t in pieces[0].templates: # type: ignore
for t in pieces[0].templates:
if t.name == node.template.name:
if t != node.template:
raise TemplateNameConflict(f"Found multiple templates with the same name: {t.name}")
found = True
break
if not found:
pieces[0]._add_sub(node.template)

pieces[-1]._add_sub(node)


_context = _HeraContext()
6 changes: 5 additions & 1 deletion src/hera/workflows/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from hera.shared._pydantic import PrivateAttr
from hera.workflows._meta_mixins import CallableTemplateMixin, ContextMixin
from hera.workflows._mixins import IOMixin, TemplateMixin
from hera.workflows.exceptions import InvalidType
from hera.workflows.exceptions import InvalidType, NodeNameConflict
from hera.workflows.models import (
DAGTask,
DAGTemplate as _ModelDAGTemplate,
Expand Down Expand Up @@ -44,11 +44,15 @@ class DAG(
target: Optional[str] = None
tasks: List[Union[Task, DAGTask]] = []

_node_names = PrivateAttr(default_factory=set)
_current_task_depends: Set[str] = PrivateAttr(set())

def _add_sub(self, node: Any):
if not isinstance(node, Task):
raise InvalidType(type(node))
if node.name in self._node_names:
raise NodeNameConflict(f"Found multiple Task nodes with name: {node.name}")
self._node_names.add(node.name)
self.tasks.append(node)

def _build_template(self) -> _ModelTemplate:
Expand Down
14 changes: 13 additions & 1 deletion src/hera/workflows/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,16 @@ class InvalidDispatchType(WorkflowsException):
...


__all__ = ["InvalidType", "InvalidTemplateCall", "InvalidDispatchType"]
class TemplateNameConflict(WorkflowsException):
"""Exception raised when multiple Templates are found with the same name.."""

...


class NodeNameConflict(WorkflowsException):
"""Exception raised when multiple Task/Step are found with the same name."""

...


__all__ = ["InvalidType", "InvalidTemplateCall", "InvalidDispatchType", "TemplateNameConflict", "NodeNameConflict"]
17 changes: 15 additions & 2 deletions src/hera/workflows/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Any, List, Optional, Union

from hera.shared._pydantic import PrivateAttr
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin, ContextMixin
from hera.workflows._mixins import (
Expand All @@ -17,7 +18,7 @@
TemplateInvocatorSubNodeMixin,
TemplateMixin,
)
from hera.workflows.exceptions import InvalidType
from hera.workflows.exceptions import InvalidType, NodeNameConflict
from hera.workflows.models import (
ParallelSteps,
Template as _ModelTemplate,
Expand Down Expand Up @@ -107,9 +108,14 @@ class Parallel(

sub_steps: List[Union[Step, _ModelWorkflowStep]] = []

_node_names = PrivateAttr(default_factory=set)

def _add_sub(self, node: Any):
if not isinstance(node, Step):
raise InvalidType(type(node))
if node.name in self._node_names:
raise NodeNameConflict(f"Found multiple Steps named: {node.name}")
self._node_names.add(node.name)
self.sub_steps.append(node)

def _build_step(self) -> List[_ModelWorkflowStep]:
Expand Down Expand Up @@ -140,6 +146,8 @@ class Steps(
* All Step objects initialised within a Parallel context will run in parallel.
"""

_node_names = PrivateAttr(default_factory=set)

sub_steps: List[
Union[
Step,
Expand Down Expand Up @@ -175,7 +183,12 @@ def _build_steps(self) -> Optional[List[ParallelSteps]]:
def _add_sub(self, node: Any):
if not isinstance(node, (Step, Parallel)):
raise InvalidType(type(node))

if isinstance(node, Step):
if node.name in self._node_names:
raise NodeNameConflict(f"Found multiple Step nodes with name: {node.name}")
self._node_names.add(node.name)
if isinstance(node, Parallel):
node._node_names = self._node_names
self.sub_steps.append(node)

def parallel(self) -> Parallel:
Expand Down
117 changes: 117 additions & 0 deletions tests/test_unit/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pytest

from hera.workflows import DAG, Steps, WorkflowTemplate, script
from hera.workflows.exceptions import NodeNameConflict, TemplateNameConflict


class TestContextNameConflicts:
"""These tests ensure that template and node name conflicts raise Exceptions.
This should validate that no two templates have the same name,
and that no two Task/Step nodes have the same name.
"""

def test_conflict_on_templates_with_same_name(self):
"""Multiple templates can't have the same name."""
name = "name-of-dag-and-script"

@script(name=name)
def example():
print("hello")

with pytest.raises(TemplateNameConflict):
with WorkflowTemplate(name="my-workflow", entrypoint=name), DAG(name=name):
example()

with pytest.raises(TemplateNameConflict):
with WorkflowTemplate(
name="my-workflow",
entrypoint=name,
), Steps(name=name):
example()

def test_no_conflict_on_tasks_with_different_names_using_same_template(self):
"""Task nodes can have different names for the same script template."""
dag_name = "dag-name"
name_1 = "task-1"
name_2 = "task-2"

@script()
def example():
print("hello")

with WorkflowTemplate(
name="my-workflow",
entrypoint=dag_name,
), DAG(name=dag_name):
example(name=name_1)
example(name=name_2)

def test_no_conflict_on_dag_and_task_with_same_name(self):
"""Dag and task node can have the same name."""
name = "name-of-dag-and-task"

@script()
def example():
print("hello")

with WorkflowTemplate(
name="my-workflow",
entrypoint=name,
), DAG(name=name):
example(name=name) # task name same as dag template

with WorkflowTemplate(
name="my-workflow",
entrypoint=name,
), Steps(name=name):
example(name=name) # step name same as steps template

def test_conflict_on_multiple_tasks_with_same_name(self):
"""Dags cannot have two task nodes with the same name."""
name = "name-of-tasks"

@script()
def hello():
print("hello")

@script()
def world():
print("world")

with pytest.raises(NodeNameConflict):
with WorkflowTemplate(name="my-workflow", entrypoint="dag"), DAG(name="dag"):
hello(name=name)
world(name=name)

with pytest.raises(NodeNameConflict):
with WorkflowTemplate(name="my-workflow", entrypoint="steps"), Steps(name="steps"):
hello(name=name)
world(name=name)

with pytest.raises(NodeNameConflict):
with WorkflowTemplate(
name="my-workflow",
entrypoint="steps",
), Steps(name="steps") as s:
hello(name=name)
with s.parallel():
world(name=name)

with pytest.raises(NodeNameConflict):
with WorkflowTemplate(
name="my-workflow",
entrypoint="steps",
), Steps(name="steps") as s:
with s.parallel():
hello(name=name)
world(name=name)

with pytest.raises(NodeNameConflict):
with WorkflowTemplate(
name="my-workflow",
entrypoint="steps",
), Steps(name="steps") as s:
with s.parallel():
hello(name=name)
with s.parallel():
world(name=name)

0 comments on commit fe4d5c9

Please sign in to comment.