Skip to content

Commit

Permalink
ITIR type inference fix stencil closure location constraints (#1185)
Browse files Browse the repository at this point in the history
- Properly propagate constraints imposed on defined location of closure inputs from closure outputs.
- Uncouple current location of fencil arguments and stencil arguments, i.e. on fencil level the iterator arguments have current location ANYWHERE and only its other properties are propagated to the stencils of a closure (e.g. dtype, defined_location, ...)
  • Loading branch information
tehrengruber committed Mar 8, 2023
1 parent 013fd0a commit be0c3fb
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 40 deletions.
37 changes: 29 additions & 8 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def visit_FunctionDefinition(
# their parameters to inherit the constraints of the arguments in a call to them. A simple
# way to do this is to run the type inference on the function itself and reindex its type
# vars when referencing the function, i.e. in a `SymRef`.
collected_types = _infer_all(fun, offset_provider=self.offset_provider, reindex=False)
collected_types = infer_all(fun, offset_provider=self.offset_provider, reindex=False)
fun_type = LetPolymorphic(dtype=collected_types.pop(id(fun)))
assert not set(self.collected_types.keys()) & set(collected_types.keys())
self.collected_types = {**self.collected_types, **collected_types}
Expand All @@ -594,7 +594,6 @@ def visit_StencilClosure(
domain = self.visit(node.domain, **kwargs)
stencil = self.visit(node.stencil, **kwargs)
output = self.visit(node.output, **kwargs)
inputs = Tuple.from_elems(*self.visit(node.inputs, **kwargs))
output_dtype = TypeVar.fresh()
output_loc = TypeVar.fresh()
self.constraints.add(
Expand All @@ -607,18 +606,40 @@ def visit_StencilClosure(
kind=Iterator(),
dtype=output_dtype,
size=Column(),
current_loc=output_loc,
defined_loc=output_loc,
),
)
)

inputs: list[Type] = self.visit(node.inputs, **kwargs)
stencil_params = []
for input_ in inputs:
stencil_param = Val(current_loc=output_loc, defined_loc=TypeVar.fresh())
self.constraints.add(
(
input_,
Val(
kind=stencil_param.kind,
dtype=stencil_param.dtype,
size=stencil_param.size,
# closure input and stencil param differ in `current_loc`
current_loc=ANYWHERE,
defined_loc=stencil_param.defined_loc,
),
)
)
stencil_params.append(stencil_param)

self.constraints.add(
(
stencil,
FunctionType(args=inputs, ret=Val(kind=Value(), dtype=output_dtype, size=Column())),
FunctionType(
args=Tuple.from_elems(*stencil_params),
ret=Val(kind=Value(), dtype=output_dtype, size=Column()),
),
)
)
return Closure(output=output, inputs=inputs)
return Closure(output=output, inputs=Tuple.from_elems(*inputs))

def visit_FencilDefinition(
self,
Expand All @@ -637,13 +658,13 @@ def visit_FencilDefinition(
params = [self.visit(p, **kwargs) for p in node.params]
self.visit(node.closures, **kwargs)
return FencilDefinitionType(
name=node.id,
name=str(node.id),
fundefs=Tuple.from_elems(*ftypes),
params=Tuple.from_elems(*params),
)


def _infer_all(
def infer_all(
node: ir.Node,
offset_provider: Optional[dict[str, Connectivity | Dimension]] = None,
reindex: bool = True,
Expand Down Expand Up @@ -674,7 +695,7 @@ def infer(
offset_provider: typing.Optional[dict[str, typing.Any]] = None,
) -> Type:
"""Infer the type of the given iterator IR expression."""
inferred_types = _infer_all(expr, offset_provider)
inferred_types = infer_all(expr, offset_provider)
return inferred_types[id(expr)]


Expand Down
122 changes: 92 additions & 30 deletions tests/next_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from gt4py.next.common import Dimension
from gt4py.next.ffront import itir_makers as im
from gt4py.next.iterator import ir, type_inference as ti
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.runtime import CartesianAxis
Expand Down Expand Up @@ -617,22 +618,22 @@ def test_stencil_closure():
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=0),
size=ti.Column(),
current_loc=ti.TypeVar(idx=1),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=1),
),
inputs=ti.Tuple.from_elems(
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=0),
size=ti.Column(),
current_loc=ti.TypeVar(idx=2),
defined_loc=ti.TypeVar(idx=2),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=1),
),
),
)
inferred = ti.infer(testee)
assert inferred == expected
assert ti.pformat(inferred) == "(It[T₂, T, T₀ᶜ]) ⇒ It[T₁, T₁, T₀ᶜ]"
assert ti.pformat(inferred) == "(It[ANYWHERE, T, T₀ᶜ]) ⇒ It[ANYWHERE, T₁, T₀ᶜ]"


def test_fencil_definition():
Expand Down Expand Up @@ -674,39 +675,100 @@ def test_fencil_definition():
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=0),
size=ti.Column(),
current_loc=ti.TypeVar(idx=1),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=1),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=0),
size=ti.Column(),
current_loc=ti.TypeVar(idx=2),
defined_loc=ti.TypeVar(idx=2),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=1),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=3),
dtype=ti.TypeVar(idx=2),
size=ti.Column(),
current_loc=ti.TypeVar(idx=4),
defined_loc=ti.TypeVar(idx=4),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=3),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=3),
dtype=ti.TypeVar(idx=2),
size=ti.Column(),
current_loc=ti.TypeVar(idx=5),
defined_loc=ti.TypeVar(idx=5),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=3),
),
),
)
inferred = ti.infer(testee)
assert inferred == expected
assert (
ti.pformat(inferred)
== "{f(intˢ, intˢ, intˢ, It[T₁, T₁, T₀ᶜ], It[T₂, T₂, T₀ᶜ], It[T₄, T₄, T₃ᶜ], It[T₅, T₅, T₃ᶜ])}"
== "{f(intˢ, intˢ, intˢ, It[ANYWHERE, T₁, T₀ᶜ], It[ANYWHERE, T₁, T₀ᶜ], It[ANYWHERE, T₃, T₂ᶜ], It[ANYWHERE, T₃, T₂ᶜ])}"
)


def test_fencil_definition_same_closure_input():
f1 = ir.FunctionDefinition(
id="f1", params=[im.sym("vertex_it")], expr=im.deref_(im.shift_("E2V")("vertex_it"))
)
f2 = ir.FunctionDefinition(id="f2", params=[im.sym("vertex_it")], expr=im.deref_("vertex_it"))

testee = ir.FencilDefinition(
id="fencil",
function_definitions=[f1, f2],
params=[im.sym("vertex_it"), im.sym("output_edge_it"), im.sym("output_vertex_it")],
closures=[
ir.StencilClosure(
domain=im.call_("unstructured_domain")(
im.call_("named_range")(
ir.AxisLiteral(value="Edge"),
ir.Literal(value="0", type="int"),
ir.Literal(value="10", type="int"),
)
),
stencil=im.ref("f1"),
output=im.ref("output_edge_it"),
inputs=[im.ref("vertex_it")],
),
ir.StencilClosure(
domain=im.call_("unstructured_domain")(
im.call_("named_range")(
ir.AxisLiteral(value="Vertex"),
ir.Literal(value="0", type="int"),
ir.Literal(value="10", type="int"),
)
),
stencil=im.ref("f2"),
output=im.ref("output_vertex_it"),
inputs=[im.ref("vertex_it")],
),
],
)

offset_provider = {
"E2V": NeighborTableOffsetProvider(
np.empty((0, 2), dtype=np.int64), Dimension("Edge"), Dimension("Vertex"), 2, False
)
}
inferred_all: dict[int, ti.Type] = ti.infer_all(testee, offset_provider)

# validate locations of fencil params
fencil_param_types = [inferred_all[id(testee.params[i])] for i in range(3)]
assert fencil_param_types[0].defined_loc == ti.Location(name="Vertex")
assert fencil_param_types[1].defined_loc == ti.Location(name="Edge")
assert fencil_param_types[2].defined_loc == ti.Location(name="Vertex")

# validate locations of stencil params
f1_param_type: ti.Val = inferred_all[id(f1.params[0])]
assert f1_param_type.current_loc == ti.Location(name="Edge")
assert f1_param_type.defined_loc == ti.Location(name="Vertex")
# f2 is polymorphic and there is no shift inside so we only get a TypeVar here
f2_param_type: ti.Val = inferred_all[id(f2.params[0])]
assert isinstance(f2_param_type.current_loc, ti.TypeVar)
assert isinstance(f2_param_type.defined_loc, ti.TypeVar)


def test_fencil_definition_with_function_definitions():
fundefs = [
Expand Down Expand Up @@ -794,51 +856,51 @@ def test_fencil_definition_with_function_definitions():
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=4),
size=ti.Column(),
current_loc=ti.TypeVar(idx=5),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=5),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=4),
size=ti.Column(),
current_loc=ti.TypeVar(idx=6),
defined_loc=ti.TypeVar(idx=6),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=5),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=7),
dtype=ti.TypeVar(idx=6),
size=ti.Column(),
current_loc=ti.TypeVar(idx=8),
defined_loc=ti.TypeVar(idx=8),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=7),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=7),
dtype=ti.TypeVar(idx=6),
size=ti.Column(),
current_loc=ti.TypeVar(idx=9),
defined_loc=ti.TypeVar(idx=9),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=7),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=10),
dtype=ti.TypeVar(idx=8),
size=ti.Column(),
current_loc=ti.TypeVar(idx=11),
defined_loc=ti.TypeVar(idx=11),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=9),
),
ti.Val(
kind=ti.Iterator(),
dtype=ti.TypeVar(idx=10),
dtype=ti.TypeVar(idx=8),
size=ti.Column(),
current_loc=ti.TypeVar(idx=12),
defined_loc=ti.TypeVar(idx=12),
current_loc=ti.ANYWHERE,
defined_loc=ti.TypeVar(idx=9),
),
),
)
inferred = ti.infer(testee)
assert inferred == expected
assert (
ti.pformat(inferred)
== "{f :: (T₀) → T₀, g :: (It[T₃, T₃, T₁²]) → T₁², foo(intˢ, intˢ, intˢ, It[T₅, T₅, T₄ᶜ], It[T₆, T, T₄ᶜ], It[T₈, T, Tᶜ], It[T₉, T, Tᶜ], It[T₁₁, T₁₁, T₁₀ᶜ], It[T₁₂, T₁₂, T₁₀ᶜ])}"
== "{f :: (T₀) → T₀, g :: (It[T₃, T₃, T₁²]) → T₁², foo(intˢ, intˢ, intˢ, It[ANYWHERE, T₅, T₄ᶜ], It[ANYWHERE, T, T₄ᶜ], It[ANYWHERE, T, Tᶜ], It[ANYWHERE, T, Tᶜ], It[ANYWHERE, T, Tᶜ], It[ANYWHERE, T, Tᶜ])}"
)


Expand Down
8 changes: 6 additions & 2 deletions tests/next_tests/iterator_tests/test_with_toy_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
v2v_arr,
)

from gt4py.next.common import Dimension
from gt4py.next.iterator import transforms
from gt4py.next.iterator.builtins import deref, lift, plus, reduce, shift
from gt4py.next.iterator.embedded import (
Expand All @@ -37,7 +36,7 @@
np_as_located_field,
)
from gt4py.next.iterator.runtime import fundef, offset
from gt4py.next.program_processors.formatters import gtfn
from gt4py.next.program_processors.formatters import gtfn, type_check
from gt4py.next.program_processors.runners import gtfn_cpu

from .conftest import run_processor
Expand Down Expand Up @@ -137,6 +136,8 @@ def sparse_stencil(non_sparse, inp):

def test_sparse_input_field(program_processor_no_gtfn_exec, lift_mode):
program_processor, validate = program_processor_no_gtfn_exec
if program_processor == type_check.check:
pytest.xfail("Partial shifts not properly supported by type inference.")
non_sparse = np_as_located_field(Edge)(np.zeros(18))
inp = np_as_located_field(Vertex, V2E)(np.asarray([[1, 2, 3, 4]] * 9))
out = np_as_located_field(Vertex)(np.zeros([9]))
Expand All @@ -159,6 +160,9 @@ def test_sparse_input_field(program_processor_no_gtfn_exec, lift_mode):

def test_sparse_input_field_v2v(program_processor_no_gtfn_exec, lift_mode):
program_processor, validate = program_processor_no_gtfn_exec
if program_processor == type_check.check:
pytest.xfail("Partial shifts not properly supported by type inference.")

non_sparse = np_as_located_field(Edge)(np.zeros(18))
inp = np_as_located_field(Vertex, V2V)(v2v_arr)
out = np_as_located_field(Vertex)(np.zeros([9]))
Expand Down

0 comments on commit be0c3fb

Please sign in to comment.