Skip to content

Commit

Permalink
fix: Improve linting/editor experience for hera models.
Browse files Browse the repository at this point in the history
Signed-off-by: DanCardin <ddcardin@gmail.com>
  • Loading branch information
DanCardin committed Jan 31, 2024
1 parent 77c50e7 commit b5f0c8e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/hera/shared/_global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from hera.auth import TokenGenerator
from hera.shared._pydantic import BaseModel, root_validator
from hera.shared._pydantic import BaseModel, get_fields, root_validator

TBase = TypeVar("TBase", bound="BaseMixin")
TypeTBase = Type[TBase]
Expand Down Expand Up @@ -119,7 +119,7 @@ def set_class_defaults(self, cls: Type[TBase], **kwargs: Any) -> None:
cls: The class to set defaults for.
kwargs: The default values to set.
"""
invalid_keys = set(kwargs) - set(cls.__fields__)
invalid_keys = set(kwargs) - set(get_fields(cls))
if invalid_keys:
raise ValueError(f"Invalid keys for class {cls}: {invalid_keys}")
self._defaults[cls].update(kwargs)
Expand All @@ -143,7 +143,7 @@ def _init_private_attributes(self):
this method. We also tried other ways including creating a metaclass that invokes hera_init after init,
but that always broke auto-complete for IDEs like VSCode.
"""
super()._init_private_attributes()
super()._init_private_attributes() # type: ignore
self.__hera_init__()

def __hera_init__(self):
Expand Down
24 changes: 21 additions & 3 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Module that holds the underlying base Pydantic models for Hera objects."""

from typing import Literal
from typing import TYPE_CHECKING, Any, Dict, Literal, Type

_PYDANTIC_VERSION: Literal[1, 2] = 1
# The pydantic v1 interface is used for both pydantic v1 and v2 in order to support
# users across both versions.

try:
from pydantic.v1 import ( # type: ignore
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
Expand All @@ -18,7 +17,6 @@
_PYDANTIC_VERSION = 2
except (ImportError, ModuleNotFoundError):
from pydantic import ( # type: ignore[assignment,no-redef]
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
Expand All @@ -28,6 +26,26 @@
_PYDANTIC_VERSION = 1


# TYPE_CHECKING-guarding specifically the `BaseModel` import helps the type checkers
# provide proper type checking to models. Without this, both mypy and pyright lose
# native pydantic hinting for `__init__` arguments.
if TYPE_CHECKING:
from pydantic import BaseModel as PydanticBaseModel
else:
try:
from pydantic.v1 import BaseModel as PydanticBaseModel # type: ignore
except (ImportError, ModuleNotFoundError):
from pydantic import BaseModel as PydanticBaseModel # type: ignore[assignment,no-redef]


def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, Any]:
"""Centralize access to __fields__."""
try:
return cls.model_fields # type: ignore
except AttributeError:
return cls.__fields__ # type: ignore


__all__ = [
"BaseModel",
"Field",
Expand Down
7 changes: 4 additions & 3 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from typing_extensions import Annotated, get_args, get_origin # type: ignore

from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import BaseModel, root_validator, validator
from hera.shared._pydantic import BaseModel, get_fields, root_validator, validator
from hera.shared.serialization import serialize
from hera.workflows._context import SubNodeMixin, _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -1211,9 +1211,10 @@ def __init__(self, model_path: str, hera_builder: Optional[Callable] = None):
self.model_path = model_path.split(".")
curr_class: Type[BaseModel] = self._get_model_class()
for key in self.model_path:
if key not in curr_class.__fields__:
fields = get_fields(curr_class)
if key not in fields:
raise ValueError(f"Model key '{key}' does not exist in class {curr_class}")
curr_class = curr_class.__fields__[key].outer_type_
curr_class = fields[key].outer_type_

@classmethod
def _get_model_class(cls) -> Type[BaseModel]:
Expand Down
27 changes: 17 additions & 10 deletions src/hera/workflows/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import ChainMap
from typing import Any, List, Optional, Union

from hera.shared._pydantic import BaseModel
from hera.shared._pydantic import BaseModel, get_fields
from hera.shared.serialization import serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -32,30 +32,36 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
parameters = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
fields = get_fields(cls)
for field in fields:
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
if object_override:
param.default = serialize(getattr(object_override, field))
elif cls.__fields__[field].default:
elif fields[field].default:
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(cls.__fields__[field].default)
param.default = serialize(fields[field].default)
parameters.append(param)
else:
# Create a Parameter from basic type annotations
if object_override:
parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field))))
parameters.append(
Parameter(
name=field,
default=serialize(getattr(object_override, field)),
)
)
else:
parameters.append(Parameter(name=field, default=cls.__fields__[field].default))
parameters.append(Parameter(name=field, default=fields[field].default))
return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
for field in get_fields(cls):
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Artifact):
artifact = get_args(annotations[field])[1]
Expand All @@ -82,15 +88,16 @@ def _get_outputs(cls) -> List[Union[Artifact, Parameter]]:
outputs = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.__fields__:
fields = get_fields(cls)
for field in fields:
if field in {"exit_code", "result"}:
continue
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)):
outputs.append(get_args(annotations[field])[1])
else:
# Create a Parameter from basic type annotations
outputs.append(Parameter(name=field, default=cls.__fields__[field].default))
outputs.append(Parameter(name=field, default=fields[field].default))
return outputs

@classmethod
Expand All @@ -102,4 +109,4 @@ def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]:
return get_args(annotation)[1]

# Create a Parameter from basic type annotations
return Parameter(name=field_name, default=cls.__fields__[field_name].default)
return Parameter(name=field_name, default=get_fields(cls)[field_name].default)

0 comments on commit b5f0c8e

Please sign in to comment.