Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Remove int type from FOAST, PAST, ITIR #1255

Merged
merged 39 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e83e183
Remove int type (without size) from FOAST, PAST, ITIR
tehrengruber May 22, 2023
3b12e66
Merge origin/main
tehrengruber May 22, 2023
7b40a51
Cleanup
tehrengruber May 22, 2023
7caa673
Fix type inference
tehrengruber May 22, 2023
6844afb
Small fixes
tehrengruber May 22, 2023
ce3956f
Fix type inference tests
tehrengruber May 22, 2023
90132c5
Fix doctests
tehrengruber May 22, 2023
3f72f87
Fix tests
tehrengruber May 22, 2023
6e71f9f
Fix format
tehrengruber May 22, 2023
d9c409d
Fix tests
tehrengruber May 22, 2023
d66c64f
Fix failing tests
tehrengruber May 24, 2023
5a4e5af
Merge remote-tracking branch 'origin/main' into remove_int_type
tehrengruber May 24, 2023
4f70fb6
Change default datatype for tests to float
tehrengruber May 30, 2023
7aebcc2
Fix format
tehrengruber May 30, 2023
62d04b4
Change default dtype in testing to float for unstructured
tehrengruber May 31, 2023
9f03c43
Remove IJKIntField type alias
tehrengruber May 31, 2023
1e9e191
Merge branch 'change_default_case_to_float' into remove_int_type
tehrengruber May 31, 2023
a3fce3a
Fix broken tests
tehrengruber May 31, 2023
cdb33ba
Fix doctest
tehrengruber May 31, 2023
8ad7d90
Merge remote-tracking branch 'origin/main' into remove_int_type
tehrengruber May 31, 2023
b5f05d1
Fix failing tests
tehrengruber May 31, 2023
25c55ad
Fix failing tests
tehrengruber May 31, 2023
bb9ed6e
Merge origin/main
tehrengruber Jun 12, 2023
24d6c7a
Address review comments
tehrengruber Jun 12, 2023
d024d91
Merge remote-tracking branch 'origin/main' into change_default_case_t…
tehrengruber Jun 12, 2023
c4aab5b
Merge origin/main
tehrengruber Jun 12, 2023
3b9c47a
Use int32 as default data type
tehrengruber Jun 12, 2023
c239bf9
Merge origin_tehrengruber/change_default_case_to_float
tehrengruber Jun 12, 2023
8f85deb
Fix tests
tehrengruber Jun 12, 2023
a7155a8
Fix tests
tehrengruber Jun 12, 2023
46d588e
Fix tests
tehrengruber Jun 12, 2023
8581783
Fix typo
tehrengruber Jun 12, 2023
0c6f399
Fix tests
tehrengruber Jun 13, 2023
47d82b6
Fix typo
tehrengruber Jun 13, 2023
efd8baa
Merge origin/main
tehrengruber Jun 14, 2023
97f3bc0
Small fix
tehrengruber Jun 14, 2023
5c7a76a
Small fix
tehrengruber Jun 14, 2023
8246e8f
Small fix
tehrengruber Jun 14, 2023
2b18b1a
Cleanup
tehrengruber Jun 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/gt4py/next/ffront/foast_pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,15 @@ def pretty_format(node: foast.LocatedNode) -> str:
Pretty print (to string) an `foast.LocatedNode`.

>>> from gt4py.next.common import Field, Dimension
>>> from gt4py.next.ffront.fbuiltins import float64
>>> from gt4py.next.ffront.decorator import field_operator
>>> IDim = Dimension("IDim")
>>> @field_operator
... def field_op(a: Field[[IDim], int]) -> Field[[IDim], int]:
... return a+1
... def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]:
... return a + 1.0
>>> print(pretty_format(field_op.foast_node))
@field_operator
def field_op(a: Field[[IDim], int64]) -> Field[[IDim], int64]:
return a + 1
def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]:
return a + 1.0
"""
return _PrettyPrinter().apply(node)
9 changes: 3 additions & 6 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ def _visit_slice_bound(
if slice_bound is None:
lowered_bound = default_value
elif isinstance(slice_bound, past.Constant):
assert (
isinstance(slice_bound.type, ts.ScalarType)
and slice_bound.type.kind == ts.ScalarKind.INT
assert isinstance(slice_bound.type, ts.ScalarType) and type_info.is_integral(
slice_bound.type
)
if slice_bound.value < 0:
lowered_bound = itir.FunCall(
Expand Down Expand Up @@ -219,7 +218,7 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
itir.Literal(value="0", type="int"),
itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN),
dim_size,
)
upper = self._visit_slice_bound(
Expand Down Expand Up @@ -338,8 +337,6 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal:
f"Scalars of kind {node.type.kind} not supported currently."
)
typename = node.type.kind.name.lower()
if typename.startswith("int"):
typename = "int"
return itir.Literal(value=str(node.value), type=typename)

raise NotImplementedError("Only scalar literals supported currently.")
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def bool(*args): # noqa: A001
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"}
TYPEBUILTINS = {"int", "int32", "int64", "float", "float32", "float64", "bool"}
TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"}
MATH_BUILTINS = (
UNARY_MATH_NUMBER_BUILTINS
| UNARY_MATH_FP_BUILTINS
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,10 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
"xor_",
}

INTEGER_BUILTINS = {"int", "int32", "int64"}
FLOATING_POINT_BUILTINS = {"float", "float32", "float64"}
#: builtin / dtype used to construct integer indices, like domain bounds
INTEGER_INDEX_BUILTIN = "int32"
INTEGER_BUILTINS = {"int32", "int64"}
FLOATING_POINT_BUILTINS = {"float32", "float64"}
TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"}

BUILTINS = {
Expand Down
40 changes: 32 additions & 8 deletions src/gt4py/next/iterator/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import Callable, Union

import numpy as np

from gt4py.next.iterator import ir as itir


Expand Down Expand Up @@ -61,17 +63,15 @@ def ensure_expr(literal_or_expr: Union[str, int, itir.Expr]) -> itir.Expr:
SymRef(id=SymbolRef('a'))

>>> ensure_expr(3)
Literal(value='3', type='int')
Literal(value='3', type='int32')

>>> ensure_expr(itir.OffsetLiteral(value="i"))
OffsetLiteral(value='i')
"""
if isinstance(literal_or_expr, str):
return ref(literal_or_expr)
elif isinstance(literal_or_expr, int):
return itir.Literal(value=str(literal_or_expr), type="int")
elif isinstance(literal_or_expr, float):
return itir.Literal(value=str(literal_or_expr), type="float")
elif isinstance(literal_or_expr, (int, float, bool)):
return literal(literal_or_expr)
return literal_or_expr


Expand Down Expand Up @@ -116,7 +116,7 @@ class call:
Examples
--------
>>> call("plus")(1, 1)
FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int'), Literal(value='1', type='int')])
FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')])
"""

def __init__(self, expr):
Expand Down Expand Up @@ -220,7 +220,7 @@ def make_tuple(*args):

def tuple_get(index, tuple_expr):
"""Create a tuple_get FunCall, shorthand for ``call("tuple_get")(index, tuple_expr)``."""
return call("tuple_get")(index, tuple_expr)
return call("tuple_get")(literal(index, itir.INTEGER_INDEX_BUILTIN), tuple_expr)


def if_(cond, true_val, false_val):
Expand Down Expand Up @@ -271,7 +271,31 @@ def shift(offset, value=None):
return call(call("shift")(*args))


def literal(value: str, typename: str):
def literal(value: str | bool | int | float, typename: str | None = None):
if isinstance(value, str):
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
if typename is None:
raise ValueError("Argument `typename` mandatory for `value` of type string.")
elif isinstance(value, bool):
assert typename in [None, "bool"]
typename = "bool"
value = str(value)
elif isinstance(value, int):
if np.iinfo(np.int32).min <= value <= np.iinfo(np.int32).max:
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
typename = "int32"
elif np.iinfo(np.int64).min <= value <= np.iinfo(np.int64).max:
typename = "int64"
else:
raise ValueError(
f"Value `{value}` is out of range to be representable as `int32` or `int64`."
)
value = str(value)
elif isinstance(value, float):
assert typename in [None, "float64"]
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
typename = "float64"
value = str(value)
else:
raise ValueError("Invalid argument.")

return itir.Literal(value=value, type=typename)


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from lark import lark, lexer as lark_lexer, visitors as lark_visitors

from gt4py.next.iterator import ir
from gt4py.next.iterator import ir, ir_makers as im


GRAMMAR = """
Expand Down Expand Up @@ -98,10 +98,10 @@ def SYM_REF(self, value: lark_lexer.Token) -> Union[ir.SymRef, ir.Literal]:
return ir.SymRef(id=value.value)

def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return ir.Literal(value=value.value, type="int")
return im.literal(int(value))
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved

def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return ir.Literal(value=value.value, type="float")
return ir.Literal(value=value.value, type="float64")

def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
v: Union[int, str] = value.value[:-1]
Expand Down
11 changes: 3 additions & 8 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

from gt4py.eve import Node
from gt4py.next import iterator
from gt4py.next.iterator import builtins
from gt4py.next.iterator import builtins, ir_makers as im
from gt4py.next.iterator.ir import (
AxisLiteral,
Expr,
FencilDefinition,
FunCall,
FunctionDefinition,
Lambda,
Literal,
NoneLiteral,
OffsetLiteral,
StencilClosure,
Expand Down Expand Up @@ -152,12 +151,8 @@ def make_node(o):
return lambdadef(o)
if isinstance(o, iterator.runtime.Offset):
return OffsetLiteral(value=o.value)
if isinstance(o, bool):
return Literal(value=str(o), type="bool")
if isinstance(o, int):
return Literal(value=str(o), type="int")
if isinstance(o, float):
return Literal(value=str(o), type="float")
if isinstance(o, (bool, int, float)):
return im.literal(o)
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(o, CartesianAxis):
return AxisLiteral(value=o.value)
if isinstance(o, tuple):
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,12 @@ def _named_range_with_offsets(
if lower_offset:
lower_bound = ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[lower_bound, ir.Literal(value=str(lower_offset), type="int")],
args=[lower_bound, ir.Literal(value=str(lower_offset), type=ir.INTEGER_INDEX_BUILTIN)],
)
if upper_offset:
upper_bound = ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[upper_bound, ir.Literal(value=str(upper_offset), type="int")],
args=[upper_bound, ir.Literal(value=str(upper_offset), type=ir.INTEGER_INDEX_BUILTIN)],
)
return ir.FunCall(
fun=ir.SymRef(id="named_range"), args=[axis_literal, lower_bound, upper_bound]
Expand Down Expand Up @@ -344,8 +344,8 @@ def _unstructured_domain(
fun=ir.SymRef(id="named_range"),
args=[
ir.AxisLiteral(value=axis),
ir.Literal(value="0", type="int"),
ir.Literal(value=str(size), type="int"),
ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN),
ir.Literal(value=str(size), type=ir.INTEGER_INDEX_BUILTIN),
],
)
]
Expand Down
38 changes: 33 additions & 5 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ def handle_constraint(
return True


class UnionPrimitive(Type):
"""Union of primitive types."""

names: tuple[str, ...]

def handle_constraint(
self, other: Type, add_constraint: abc.Callable[[Type, Type], None]
) -> bool:
if isinstance(other, UnionPrimitive):
raise AssertionError("`UnionPrimitive` may only appear on one side of a constraint.")
if not isinstance(other, Primitive):
return False

return other.name in self.names


class Value(Type):
"""Marker for values."""

Expand Down Expand Up @@ -339,9 +355,16 @@ class LetPolymorphic(Type):
dtype: Type


def _default_constraints():
return {
(FLOAT_DTYPE, UnionPrimitive(names=("float32", "float64"))),
(INT_DTYPE, UnionPrimitive(names=("int32", "int64"))),
}


BOOL_DTYPE = Primitive(name="bool")
INT_DTYPE = Primitive(name="int")
FLOAT_DTYPE = Primitive(name="float")
INT_DTYPE = TypeVar.fresh()
FLOAT_DTYPE = TypeVar.fresh()
AXIS_DTYPE = Primitive(name="axis")
NAMED_RANGE_DTYPE = Primitive(name="named_range")
DOMAIN_DTYPE = Primitive(name="domain")
Expand Down Expand Up @@ -551,7 +574,7 @@ class _TypeInferrer(eve.traits.VisitorWithSymbolTableTrait, eve.NodeTranslator):

offset_provider: Optional[dict[str, Connectivity | Dimension]]
collected_types: dict[int, Type] = dataclasses.field(default_factory=dict)
constraints: set[tuple[Type, Type]] = dataclasses.field(default_factory=set)
constraints: set[tuple[Type, Type]] = dataclasses.field(default_factory=_default_constraints)

def visit(self, node, **kwargs) -> typing.Any:
result = super().visit(node, **kwargs)
Expand Down Expand Up @@ -657,8 +680,13 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type:
# Calls to `tuple_get` are handled as being part of the grammar, not as function calls.
if len(node.args) != 2:
raise TypeError("`tuple_get` requires exactly two arguments.")
if not isinstance(node.args[0], ir.Literal) or node.args[0].type != "int":
raise TypeError("The first argument to `tuple_get` must be a literal int.")
if (
not isinstance(node.args[0], ir.Literal)
or node.args[0].type != ir.INTEGER_INDEX_BUILTIN
):
raise TypeError(
f"The first argument to `tuple_get` must be a literal of type `{ir.INTEGER_INDEX_BUILTIN}`."
)
idx = int(node.args[0].value)
tup = self.visit(node.args[1], **kwargs)
kind = TypeVar.fresh() # `kind == Iterator()` means splitting an iterator of tuples
Expand Down
2 changes: 0 additions & 2 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ class GTFNCodegen(codegen.TemplatedGenerator):
"maximum": "std::max",
"fmod": "std::fmod",
"power": "std::pow",
"float": "double",
"float32": "float",
"float64": "double",
"int": "long",
"int32": "std::int32_t",
"int64": "std::int64_t",
"bool": "bool",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@
def pytype_to_cpptype(t: str):
try:
return {
"float": "double",
"float32": "float",
"float64": "double",
"int": "long",
"int32": "std::int32_t",
"int64": "std::int64_t",
"bool": "bool",
Expand Down
10 changes: 9 additions & 1 deletion src/gt4py/next/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def visit(self, node, *, memo: dict[T, T]) -> typing.Any: # type: ignore[overri
return node


def _assert_constituent_types(value: typing.Any, allowed_types: tuple[type, ...]) -> None:
if isinstance(value, tuple):
for el in value:
_assert_constituent_types(el, allowed_types)
else:
assert isinstance(value, allowed_types)


class _Renamer:
"""Efficiently rename (that is, replace) nodes in a type expression.

Expand Down Expand Up @@ -191,7 +199,7 @@ def collect_parents(node: Type) -> None:
self._parents.setdefault(child, []).append((node, typing.cast(str, field)))
collect_parents(child)
else:
assert isinstance(child, (int, str))
_assert_constituent_types(child, (int, str))

collect_parents(dtype)

Expand Down
Loading