Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style[next]: more strict typing #1494

Merged
merged 9 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,16 @@ warn_unused_ignores = true

# GT4Py configs
[[tool.mypy.overrides]]
allow_incomplete_defs = false
allow_untyped_defs = false
ignore_missing_imports = false
module = 'gt4py.*'

[[tool.mypy.overrides]]
# The following ignore_errors are only temporary.
# TODO: Fix errors and enable these settings.
disallow_incomplete_defs = false
disallow_untyped_defs = false
allow_incomplete_defs = true
allow_untyped_defs = true
follow_imports = 'silent'
module = 'gt4py.cartesian.*'
warn_unused_ignores = false
Expand Down Expand Up @@ -186,10 +188,6 @@ module = 'gt4py.cartesian.frontend.defir_to_gtir'
ignore_errors = true
module = 'gt4py.cartesian.frontend.meta'

[[tool.mypy.overrides]]
disallow_untyped_defs = true
module = 'gt4py.eve.*'

[[tool.mypy.overrides]]
module = 'gt4py.eve.extended_typing'
warn_unused_ignores = false
Expand All @@ -202,14 +200,14 @@ module = 'gt4py.storage.*'
warn_unused_ignores = false

[[tool.mypy.overrides]]
# # TODO: this should be changed to true after a transition period
disallow_incomplete_defs = false
module = 'gt4py.next.*'
allow_incomplete_defs = true
allow_untyped_defs = true
module = 'gt4py.next.iterator.*'

[[tool.mypy.overrides]]
# TODO: temporarily to propagate it to all of next
disallow_incomplete_defs = true
module = 'gt4py.next.ffront.*'
allow_incomplete_defs = true
allow_untyped_defs = true
module = 'gt4py.next.program_processors.runners.dace_iterator.*'

[[tool.mypy.overrides]]
ignore_errors = true
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ def __init__(self, definition, *, options, externals=None, dtypes=None):
self.block = None
self.dtypes = dtypes

def __str__(self):
def __str__(self) -> str:
result = "<GT4Py.GTScriptParser> {\n"
result += "\n".join("\t{}: {}".format(name, getattr(self, name)) for name in vars(self))
result += "\n}"
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class LevelMarker(enum.Enum):
START = 0
END = -1

def __str__(self):
def __str__(self) -> str:
return self.name


Expand Down Expand Up @@ -251,7 +251,7 @@ def from_value(cls, value):

return result

def __str__(self):
def __str__(self) -> str:
return self.name


Expand All @@ -268,7 +268,7 @@ class DataType(enum.Enum):
FLOAT32 = 104
FLOAT64 = 108

def __str__(self):
def __str__(self) -> str:
return self.name

@property
Expand Down Expand Up @@ -665,7 +665,7 @@ def symbol(self):
elif self == self.FORWARD:
return "->"

def __str__(self):
def __str__(self) -> str:
return self.name

def __lshift__(self, steps: int):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def from_dace_storage(cls, schedule):
class AxisBound(common.AxisBound):
axis: Axis

def __str__(self):
def __str__(self) -> str:
return get_axis_bound_str(self, self.axis.domain_symbol())

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Axis(enum.Enum):
J = 1
K = 2

def __str__(self):
def __str__(self) -> str:
return self.name

names = [ax.name for ax in Axis]
Expand Down Expand Up @@ -191,7 +191,7 @@ def __repr__(self):
def __hash__(self):
return tuple.__hash__(self)

def __str__(self):
def __str__(self) -> str:
return tuple.__repr__(self)

@property
Expand Down Expand Up @@ -384,7 +384,7 @@ def __repr__(self):
def __hash__(self):
return tuple.__hash__(self)

def __str__(self):
def __str__(self) -> str:
return tuple.__repr__(self)

@property
Expand Down
10 changes: 5 additions & 5 deletions src/gt4py/cartesian/gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def __repr__(self):
def __eq__(self, other):
return repr(self) == repr(other)

def __str__(self):
def __str__(self) -> str:
return f"{self.axis}[{self.index}] + {self.offset}"

def __add__(self, offset: int):
Expand Down Expand Up @@ -516,7 +516,7 @@ def __init__(self, axis: str, start: int, end: int):
def __repr__(self):
return f"AxisInterval(axis={self.axis}, start={self.start}, end={self.end})"

def __str__(self):
def __str__(self) -> str:
return f"{self.axis}[{self.start}:{self.end}]"

def __len__(self):
Expand All @@ -532,7 +532,7 @@ def __init__(self, name: str, shift: int):
def __repr__(self):
return f"ShiftedAxis(name={self.name}, shift={self.shift})"

def __str__(self):
def __str__(self) -> str:
return f"{self.name}+{self.shift}"

def __add__(self, shift):
Expand All @@ -559,7 +559,7 @@ def __gt_axis_name__(self) -> str:
def __repr__(self):
return f"Axis(name={self.name})"

def __str__(self):
def __str__(self) -> str:
return self.name

def __getitem__(self, interval):
Expand Down Expand Up @@ -654,7 +654,7 @@ def __repr__(self):
args = f"dtype={self.dtype!r}, axes={self.axes!r}, data_dims={self.data_dims!r}"
return f"_FieldDescriptor({args})"

def __str__(self):
def __str__(self) -> str:
return (
f"Field<[{', '.join(str(ax) for ax in self.axes)}], ({self.dtype}, {self.data_dims})>"
)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, joiner, index):
self.joiner = joiner
self.index = index

def __str__(self):
def __str__(self) -> str:
return self.joiner.joiner_str if self.index < self.joiner.n_items - 1 else ""

def __init__(self, joiner_str):
Expand Down Expand Up @@ -138,5 +138,5 @@ def __iadd__(self, source_line):
def __len__(self):
return len(self.lines)

def __str__(self):
def __str__(self) -> str:
return self.text
17 changes: 17 additions & 0 deletions src/gt4py/eve/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,23 @@ def generic_dump(cls, node: RootNode, **kwargs: Any) -> str:
"""
return str(node)

@overload
havogt marked this conversation as resolved.
Show resolved Hide resolved
def generic_visit(self, node: Node, **kwargs: Any) -> str: ...

@overload
def generic_visit(
self,
node: Union[
list,
tuple,
collections.abc.Set,
collections.abc.Sequence,
dict,
collections.abc.Mapping,
],
**kwargs: Any,
) -> Collection[str]: ...

def generic_visit(self, node: RootNode, **kwargs: Any) -> Union[str, Collection[str]]:
if isinstance(node, Node):
template, key = self.get_template(node)
Expand Down
26 changes: 5 additions & 21 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DimensionKind(StrEnum):
VERTICAL = "vertical"
LOCAL = "local"

def __str__(self):
def __str__(self) -> str:
return self.value


Expand All @@ -80,7 +80,7 @@ class Dimension:
value: str
kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL)

def __str__(self):
def __str__(self) -> str:
return f"{self.value}[{self.kind}]"

def __call__(self, val: int) -> NamedIndex:
Expand Down Expand Up @@ -641,7 +641,7 @@ def asnumpy(self) -> np.ndarray: ...
def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def restrict(self, item: AnyIndexSpec) -> Field: ...
def restrict(self, item: AnyIndexSpec) -> Self: ...

@abc.abstractmethod
def as_scalar(self) -> core_defs.ScalarT: ...
Expand All @@ -651,7 +651,7 @@ def as_scalar(self) -> core_defs.ScalarT: ...
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...

@abc.abstractmethod
def __getitem__(self, item: AnyIndexSpec) -> Field: ...
def __getitem__(self, item: AnyIndexSpec) -> Self: ...

@abc.abstractmethod
def __abs__(self) -> Field: ...
Expand Down Expand Up @@ -867,22 +867,6 @@ def _connectivity(
raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class GTInfo:
definition: Any
ir: Any


@dataclasses.dataclass(frozen=True)
class Backend:
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# TODO : proper definition and implementation
def generate_operator(self, ir):
return ir


@runtime_checkable
class Connectivity(Protocol):
max_neighbors: int
Expand Down Expand Up @@ -1083,7 +1067,7 @@ class FieldBuiltinFuncRegistry:
collections.ChainMap()
)

def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, **kwargs: Any) -> None:
cls._builtin_func_map = collections.ChainMap(
{}, # New empty `dict` for new registrations on this class
*[
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import functools
import itertools
import operator
from collections.abc import Iterator, Sequence

from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.eve.extended_typing import Any, Optional, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions

Expand Down Expand Up @@ -148,9 +149,11 @@ def restrict_to_intersection(
)


def iterate_domain(domain: common.Domain):
def iterate_domain(
domain: common.Domain,
) -> Iterator[tuple[tuple[common.Dimension, int]]]:
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i))
yield tuple(zip(domain.dims, i)) # type: ignore[misc] # trust me, `i` is `tuple[int, ...]`


def _expand_ellipsis(
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import contextlib
import contextvars as cvars
from collections.abc import Generator
from typing import Any

import gt4py.eve as eve
Expand All @@ -39,7 +40,7 @@ def new_context(
*,
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
):
) -> Generator[cvars.Context, None, None]:
import gt4py.next.embedded.context as this_module

updates: list[tuple[cvars.ContextVar[Any], Any]] = []
Expand All @@ -51,7 +52,7 @@ def new_context(
# Create new context with provided values
ctx = cvars.copy_context()

def ctx_updater(*args):
def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None:
for cvar, value in args:
cvar.set(value)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArra


def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
builtin_name: str, array_builtin_name: str, reverse: bool = False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
cls_ = _get_nd_array_class(*fields)
Expand Down Expand Up @@ -228,7 +228,7 @@ def remap(

__call__ = remap # type: ignore[assignment]

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
def restrict(self, index: common.AnyIndexSpec) -> NdArrayField:
new_domain, buffer_slice = self._slice(index)
new_buffer = self.ndarray[buffer_slice]
new_buffer = self.__class__.array_ns.asarray(new_buffer)
Expand Down Expand Up @@ -435,7 +435,7 @@ def inverse_image(

return new_dims

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField:
cache_key = (id(self.ndarray), self.domain, index)

if (restricted_connectivity := self._cache.get(cache_key, None)) is None:
Expand Down
Loading
Loading