Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: DaCe scalar argument in stencil closure #1293

Merged
merged 6 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
sdfg: dace.SDFG = sdfg_genenerator.visit(program)

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if hasattr(v, "shape")}
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _visit_parallel_stencil_closure(
# Create an SDFG for the tasklet that computes a single item of the output domain.
index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}

input_arrays = [(array_table[name], name, self.storage_types[name]) for name in input_names]
input_arrays = [(name, self.storage_types[name]) for name in input_names]
conn_arrays = [(array_table[name], name) for name in conn_names]

context, _, results = closure_to_tasklet_sdfg(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,18 @@ def visit_Lambda(
context.body.add_array(name, shape=shape, strides=strides, dtype=dtype)

# Translate the function's body
result: ValueExpr = self.visit(node.expr)[0]
result: ValueExpr | SymbolExpr = self.visit(node.expr)[0]
# Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors
result = self.add_expr_tasklet([(result, "result")], "result", result.dtype, "forward")[0]
if isinstance(result, ValueExpr):
result_name = unique_var_name()
self.context.body.add_scalar(result_name, result.dtype, transient=True)
result_access = self.context.state.add_access(result_name)
self.context.state.add_edge(
result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]")
)
result = ValueExpr(value=result_access, dtype=result.dtype)
else:
result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0]
self.context.body.arrays[result.value.data].transient = False
self.context = prev_context

Expand All @@ -319,7 +328,7 @@ def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | Iter
def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]:
node_type = self.node_types[id(node)]
assert isinstance(node_type, Val)
return [SymbolExpr(node.value, node_type.dtype)]
return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))]

def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr:
if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref":
Expand Down Expand Up @@ -608,7 +617,7 @@ def closure_to_tasklet_sdfg(
node: itir.StencilClosure,
offset_provider: dict[str, Any],
domain: dict[str, str],
inputs: Sequence[tuple[dace.ndarray, str, ts.TypeSpec]],
inputs: Sequence[tuple[str, ts.TypeSpec]],
connectivities: Sequence[tuple[dace.ndarray, str]],
node_types: dict[int, next_typing.Type],
) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]:
Expand All @@ -624,19 +633,26 @@ def closure_to_tasklet_sdfg(
access = state.add_access(name)
idx_accesses[dim] = access
state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0"))
for _, name, ty in inputs:
assert isinstance(ty, ts.FieldType)
ndim = len(ty.dims)
shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim)]
stride = [
dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim)
]
dims = [dim.value for dim in ty.dims]
dtype = as_dace_type(ty.dtype)
body.add_array(name, shape=shape, strides=stride, dtype=dtype)
field = state.add_access(name)
indices = {dim: idx_accesses[dim] for dim in domain.keys()}
symbol_map[name] = IteratorExpr(field, indices, dtype, dims)
for name, ty in inputs:
if isinstance(ty, ts.FieldType):
ndim = len(ty.dims)
shape = [
dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim)
]
stride = [
dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim)
]
dims = [dim.value for dim in ty.dims]
dtype = as_dace_type(ty.dtype)
body.add_array(name, shape=shape, strides=stride, dtype=dtype)
field = state.add_access(name)
indices = {dim: idx_accesses[dim] for dim in domain.keys()}
symbol_map[name] = IteratorExpr(field, indices, dtype, dims)
else:
assert isinstance(ty, ts.ScalarType)
dtype = as_dace_type(ty)
body.add_scalar(name, dtype=dtype)
symbol_map[name] = ValueExpr(state.add_access(name), dtype)
for arr, name in connectivities:
shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)]
stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)]
Expand All @@ -654,7 +670,7 @@ def closure_to_tasklet_sdfg(
inner_outputs = _visit_closure_callable(
node,
translator,
[name for _, name, _ in inputs],
[name for name, _ in inputs],
)
for output in inner_outputs:
context.body.arrays[output.value.data].transient = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatFiel

def test_scalar_arg(unstructured_case): # noqa: F811 # fixtures
"""Test scalar argument being turned into 0-dim field."""
if unstructured_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def testee(a: int32) -> cases.VField:
Expand All @@ -198,9 +196,6 @@ def testee(a: int32) -> cases.VField:


def test_nested_scalar_arg(unstructured_case): # noqa: F811 # fixtures
if unstructured_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def testee_inner(a: int32) -> cases.VField:
return broadcast(a + 1, (Vertex,))
Expand Down Expand Up @@ -587,9 +582,6 @@ def expected(a, b, c, d):

@pytest.mark.parametrize("left, right", [(2, 3), (3, 2)])
def test_ternary_operator(cartesian_case, left, right):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def testee(a: cases.IField, b: cases.IField, left: int32, right: int32) -> cases.IField:
return a if left < right else b
Expand Down Expand Up @@ -917,7 +909,7 @@ def return_undefined():

def test_zero_dims_fields(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")
pytest.xfail("Not supported in DaCe backend: zero-dimensional fields")
edopao marked this conversation as resolved.
Show resolved Hide resolved

@gtx.field_operator
def implicit_broadcast_scalar(inp: cases.EmptyField):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,6 @@ def conditional_nested_tuple(


def test_broadcast_simple(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def simple_broadcast(inp: cases.IField) -> cases.IJField:
return broadcast(inp, (IDim, JDim))
Expand All @@ -221,8 +218,6 @@ def simple_broadcast(inp: cases.IField) -> cases.IJField:


def test_broadcast_scalar(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")
size = cartesian_case.default_sizes[IDim]

@gtx.field_operator
Expand All @@ -233,9 +228,6 @@ def scalar_broadcast() -> gtx.Field[[IDim], float64]:


def test_broadcast_two_fields(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def broadcast_two_fields(inp1: cases.IField, inp2: gtx.Field[[JDim], int32]) -> cases.IJField:
a = broadcast(inp1, (IDim, JDim))
Expand All @@ -248,9 +240,6 @@ def broadcast_two_fields(inp1: cases.IField, inp2: gtx.Field[[JDim], int32]) ->


def test_broadcast_shifted(cartesian_case):
if cartesian_case.backend == dace_iterator.run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: broadcast")

@gtx.field_operator
def simple_broadcast(inp: cases.IField) -> cases.IJField:
bcasted = broadcast(inp, (IDim, JDim))
Expand Down