Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate foast & past typing information to ITIR #1199

Merged
Merged
5 changes: 5 additions & 0 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def _visit_assign(self, node: foast.Assign, **kwargs) -> tuple[itir.Sym, itir.Ex
return sym, expr

def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
return itir.Sym(id=node.id, kind=kind, dtype=dtype)
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef:
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ def visit_Name(self, node: past.Name, **kwargs) -> itir.SymRef:
return itir.SymRef(id=node.id)

def visit_Symbol(self, node: past.Symbol, **kwargs) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
return itir.Sym(id=node.id, kind=kind, dtype=dtype)
return itir.Sym(id=node.id)

def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall:
Expand Down
19 changes: 18 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import ClassVar, List, Union
import typing
from typing import ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
Expand All @@ -38,6 +39,22 @@ def __hash__(self) -> int:

class Sym(Node): # helper
id: Coerced[SymbolName] # noqa: A003
# TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the
# type inference.
kind: Optional[typing.Literal["Iterator", "Value"]] = None
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
dtype: Optional[str] = None

@datamodels.validator("dtype")
def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value not in ["Iterator", "Value"]:
raise ValueError(f"Invalid kind `{value}`, must be one of `Iterator`, `Value`.")

@datamodels.validator("dtype")
def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value not in TYPEBUILTINS:
raise ValueError(
f"Invalid dtype `{value}`, must be one of `{'`, `'.join(TYPEBUILTINS)}`."
)
egparedes marked this conversation as resolved.
Show resolved Hide resolved


@noninstantiable
Expand Down
20 changes: 19 additions & 1 deletion src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,25 @@ def visit(self, node, **kwargs) -> typing.Any:
return result

def visit_Sym(self, node: ir.Sym, **kwargs) -> Type:
return TypeVar.fresh()
result = TypeVar.fresh()
if node.kind:
kind = {"Iterator": Iterator(), "Value": Value()}[node.kind]
Comment on lines +561 to +562
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not very important, but I think it would better to use explicit pattern matching:

Suggested change
if node.kind:
kind = {"Iterator": Iterator(), "Value": Value()}[node.kind]
match node.kind:
case "Iterator":
kind = Iterator()
case "Value":
kind = Value()
case _:
raise AssertionError(...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the point, but I like the less verbose version more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, for just two cases, I don't think the dict version brings any value. Even a simple if-else would be better.

self.constraints.add(
(Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result)
)
if node.dtype:
assert node.dtype in ir.TYPEBUILTINS
self.constraints.add(
(
Val(
dtype=Primitive(name=node.dtype),
current_loc=TypeVar.fresh(),
defined_loc=TypeVar.fresh(),
),
result,
)
)
return result

def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type:
if node.id in ir.BUILTINS:
Expand Down
20 changes: 20 additions & 0 deletions tests/next_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ def test_lambda():
assert ti.pformat(inferred) == "(T₀) → T₀"


def test_typed_lambda():
testee = ir.Lambda(
params=[ir.Sym(id="x", kind="Iterator", dtype="float")], expr=ir.SymRef(id="x")
)
expected_val = ti.Val(
kind=ti.Iterator(),
dtype=ti.Primitive(name="float"),
size=ti.TypeVar(idx=0),
current_loc=ti.TypeVar(idx=1),
defined_loc=ti.TypeVar(idx=2),
)
expected = ti.FunctionType(
args=ti.Tuple.from_elems(expected_val),
ret=expected_val,
)
inferred = ti.infer(testee)
assert inferred == expected
assert ti.pformat(inferred) == "(It[T₁, T₂, float⁰]) → It[T₁, T₂, float⁰]"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is being tested here that hasn't been in the first assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea of @fthaler was to test the type pretty printer like this. Additionally this gives a more readable representation of the inferred type.



def test_plus():
testee = ir.SymRef(id="plus")
t = ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1))
Expand Down