Skip to content

Commit

Permalink
style[next]: more strict typing (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Mar 19, 2024
1 parent ee17241 commit b26e6a3
Show file tree
Hide file tree
Showing 43 changed files with 264 additions and 208 deletions.
22 changes: 10 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,16 @@ warn_unused_ignores = true

# GT4Py configs
[[tool.mypy.overrides]]
allow_incomplete_defs = false
allow_untyped_defs = false
ignore_missing_imports = false
module = 'gt4py.*'

[[tool.mypy.overrides]]
# The following ignore_errors are only temporary.
# TODO: Fix errors and enable these settings.
disallow_incomplete_defs = false
disallow_untyped_defs = false
allow_incomplete_defs = true
allow_untyped_defs = true
follow_imports = 'silent'
module = 'gt4py.cartesian.*'
warn_unused_ignores = false
Expand Down Expand Up @@ -186,10 +188,6 @@ module = 'gt4py.cartesian.frontend.defir_to_gtir'
ignore_errors = true
module = 'gt4py.cartesian.frontend.meta'

[[tool.mypy.overrides]]
disallow_untyped_defs = true
module = 'gt4py.eve.*'

[[tool.mypy.overrides]]
module = 'gt4py.eve.extended_typing'
warn_unused_ignores = false
Expand All @@ -202,14 +200,14 @@ module = 'gt4py.storage.*'
warn_unused_ignores = false

[[tool.mypy.overrides]]
# # TODO: this should be changed to true after a transition period
disallow_incomplete_defs = false
module = 'gt4py.next.*'
allow_incomplete_defs = true
allow_untyped_defs = true
module = 'gt4py.next.iterator.*'

[[tool.mypy.overrides]]
# TODO: temporarily to propagate it to all of next
disallow_incomplete_defs = true
module = 'gt4py.next.ffront.*'
allow_incomplete_defs = true
allow_untyped_defs = true
module = 'gt4py.next.program_processors.runners.dace_iterator.*'

[[tool.mypy.overrides]]
ignore_errors = true
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ def __init__(self, definition, *, options, externals=None, dtypes=None):
self.block = None
self.dtypes = dtypes

def __str__(self):
def __str__(self) -> str:
result = "<GT4Py.GTScriptParser> {\n"
result += "\n".join("\t{}: {}".format(name, getattr(self, name)) for name in vars(self))
result += "\n}"
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class LevelMarker(enum.Enum):
START = 0
END = -1

def __str__(self):
def __str__(self) -> str:
return self.name


Expand Down Expand Up @@ -251,7 +251,7 @@ def from_value(cls, value):

return result

def __str__(self):
def __str__(self) -> str:
return self.name


Expand All @@ -268,7 +268,7 @@ class DataType(enum.Enum):
FLOAT32 = 104
FLOAT64 = 108

def __str__(self):
def __str__(self) -> str:
return self.name

@property
Expand Down Expand Up @@ -665,7 +665,7 @@ def symbol(self):
elif self == self.FORWARD:
return "->"

def __str__(self):
def __str__(self) -> str:
return self.name

def __lshift__(self, steps: int):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def from_dace_storage(cls, schedule):
class AxisBound(common.AxisBound):
axis: Axis

def __str__(self):
def __str__(self) -> str:
return get_axis_bound_str(self, self.axis.domain_symbol())

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Axis(enum.Enum):
J = 1
K = 2

def __str__(self):
def __str__(self) -> str:
return self.name

names = [ax.name for ax in Axis]
Expand Down Expand Up @@ -191,7 +191,7 @@ def __repr__(self):
def __hash__(self):
return tuple.__hash__(self)

def __str__(self):
def __str__(self) -> str:
return tuple.__repr__(self)

@property
Expand Down Expand Up @@ -384,7 +384,7 @@ def __repr__(self):
def __hash__(self):
return tuple.__hash__(self)

def __str__(self):
def __str__(self) -> str:
return tuple.__repr__(self)

@property
Expand Down
10 changes: 5 additions & 5 deletions src/gt4py/cartesian/gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def __repr__(self):
def __eq__(self, other):
return repr(self) == repr(other)

def __str__(self):
def __str__(self) -> str:
return f"{self.axis}[{self.index}] + {self.offset}"

def __add__(self, offset: int):
Expand Down Expand Up @@ -516,7 +516,7 @@ def __init__(self, axis: str, start: int, end: int):
def __repr__(self):
return f"AxisInterval(axis={self.axis}, start={self.start}, end={self.end})"

def __str__(self):
def __str__(self) -> str:
return f"{self.axis}[{self.start}:{self.end}]"

def __len__(self):
Expand All @@ -532,7 +532,7 @@ def __init__(self, name: str, shift: int):
def __repr__(self):
return f"ShiftedAxis(name={self.name}, shift={self.shift})"

def __str__(self):
def __str__(self) -> str:
return f"{self.name}+{self.shift}"

def __add__(self, shift):
Expand All @@ -559,7 +559,7 @@ def __gt_axis_name__(self) -> str:
def __repr__(self):
return f"Axis(name={self.name})"

def __str__(self):
def __str__(self) -> str:
return self.name

def __getitem__(self, interval):
Expand Down Expand Up @@ -654,7 +654,7 @@ def __repr__(self):
args = f"dtype={self.dtype!r}, axes={self.axes!r}, data_dims={self.data_dims!r}"
return f"_FieldDescriptor({args})"

def __str__(self):
def __str__(self) -> str:
return (
f"Field<[{', '.join(str(ax) for ax in self.axes)}], ({self.dtype}, {self.data_dims})>"
)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, joiner, index):
self.joiner = joiner
self.index = index

def __str__(self):
def __str__(self) -> str:
return self.joiner.joiner_str if self.index < self.joiner.n_items - 1 else ""

def __init__(self, joiner_str):
Expand Down Expand Up @@ -138,5 +138,5 @@ def __iadd__(self, source_line):
def __len__(self):
return len(self.lines)

def __str__(self):
def __str__(self) -> str:
return self.text
17 changes: 17 additions & 0 deletions src/gt4py/eve/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,23 @@ def generic_dump(cls, node: RootNode, **kwargs: Any) -> str:
"""
return str(node)

@overload
def generic_visit(self, node: Node, **kwargs: Any) -> str: ...

@overload
def generic_visit(
self,
node: Union[
list,
tuple,
collections.abc.Set,
collections.abc.Sequence,
dict,
collections.abc.Mapping,
],
**kwargs: Any,
) -> Collection[str]: ...

def generic_visit(self, node: RootNode, **kwargs: Any) -> Union[str, Collection[str]]:
if isinstance(node, Node):
template, key = self.get_template(node)
Expand Down
26 changes: 5 additions & 21 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DimensionKind(StrEnum):
VERTICAL = "vertical"
LOCAL = "local"

def __str__(self):
def __str__(self) -> str:
return self.value


Expand All @@ -80,7 +80,7 @@ class Dimension:
value: str
kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL)

def __str__(self):
def __str__(self) -> str:
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
Expand Down Expand Up @@ -641,7 +641,7 @@ def asnumpy(self) -> np.ndarray: ...
def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def restrict(self, item: AnyIndexSpec) -> Field: ...
def restrict(self, item: AnyIndexSpec) -> Self: ...

@abc.abstractmethod
def as_scalar(self) -> core_defs.ScalarT: ...
Expand All @@ -651,7 +651,7 @@ def as_scalar(self) -> core_defs.ScalarT: ...
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def __getitem__(self, item: AnyIndexSpec) -> Field: ...
def __getitem__(self, item: AnyIndexSpec) -> Self: ...

@abc.abstractmethod
def __abs__(self) -> Field: ...
Expand Down Expand Up @@ -867,22 +867,6 @@ def _connectivity(
raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class GTInfo:
definition: Any
ir: Any


@dataclasses.dataclass(frozen=True)
class Backend:
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# TODO : proper definition and implementation
def generate_operator(self, ir):
return ir


@runtime_checkable
class Connectivity(Protocol):
max_neighbors: int
Expand Down Expand Up @@ -1083,7 +1067,7 @@ class FieldBuiltinFuncRegistry:
collections.ChainMap()
)

def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, **kwargs: Any) -> None:
cls._builtin_func_map = collections.ChainMap(
{}, # New empty `dict` for new registrations on this class
*[
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import functools
import itertools
import operator
from collections.abc import Iterator, Sequence

from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.eve.extended_typing import Any, Optional, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions

Expand Down Expand Up @@ -148,9 +149,11 @@ def restrict_to_intersection(
)


def iterate_domain(domain: common.Domain):
def iterate_domain(
domain: common.Domain,
) -> Iterator[tuple[tuple[common.Dimension, int]]]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i))
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]`


def _expand_ellipsis(
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import contextlib
import contextvars as cvars
from collections.abc import Generator
from typing import Any

import gt4py.eve as eve
Expand All @@ -39,7 +40,7 @@ def new_context(
*,
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
):
) -> Generator[cvars.Context, None, None]:
import gt4py.next.embedded.context as this_module

updates: list[tuple[cvars.ContextVar[Any], Any]] = []
Expand All @@ -51,7 +52,7 @@ def new_context(
# Create new context with provided values
ctx = cvars.copy_context()

def ctx_updater(*args):
def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None:
for cvar, value in args:
cvar.set(value)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArra


def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
builtin_name: str, array_builtin_name: str, reverse: bool = False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
cls_ = _get_nd_array_class(*fields)
Expand Down Expand Up @@ -228,7 +228,7 @@ def remap(

__call__ = remap # type: ignore[assignment]

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
def restrict(self, index: common.AnyIndexSpec) -> NdArrayField:
new_domain, buffer_slice = self._slice(index)
new_buffer = self.ndarray[buffer_slice]
new_buffer = self.__class__.array_ns.asarray(new_buffer)
Expand Down Expand Up @@ -435,7 +435,7 @@ def inverse_image(

return new_dims

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField:
cache_key = (id(self.ndarray), self.domain, index)

if (restricted_connectivity := self._cache.get(cache_key, None)) is None:
Expand Down
Loading

0 comments on commit b26e6a3

Please sign in to comment.