Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: generic scalar expressions as fieldop args #1633

Merged
merged 12 commits into from
Sep 19, 2024
11 changes: 3 additions & 8 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,14 @@ def applied_as_fieldop(*fields) -> ts.FieldType:


@_register_builtin_type_synthesizer
def cond(
pred: ts.ScalarType,
true_branch: ts.DataType,
false_branch: ts.DataType,
) -> ts.FieldType | ts.DeferredType:
def cond(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType:
def type_synthesizer_per_element(
pred: ts.ScalarType,
true_branch: ts.FieldType | ts.DeferredType,
false_branch: ts.FieldType | ts.DeferredType,
true_branch: ts.DataType,
false_branch: ts.DataType,
):
assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL
assert true_branch == false_branch
assert isinstance(true_branch, (ts.FieldType, ts.DeferredType))

return true_branch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def _get_symbolic_value(
f"__out = {symbolic_expr}",
)
temp_name, _ = sdfg.add_scalar(
f"__{temp_name or 'tmp'}",
temp_name or sdfg.temp_data_name(),
dace_fieldview_util.as_dace_type(scalar_type),
find_new_name=True,
transient=True,
Expand Down Expand Up @@ -369,6 +369,93 @@ def translate_literal(
return [(data_node, data_type)]


def translate_scalar_expr(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> list[TemporaryData]:
assert isinstance(node, gtir.FunCall)
assert isinstance(node.type, ts.ScalarType)

args = []
connectors = []
scalar_expr_args = []

for arg_expr in node.args:
visit_expr = True
if isinstance(arg_expr, gtir.SymRef):
try:
# `gt_symbol` refers to symbols defined in the GT4Py program
gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id)
if not isinstance(gt_symbol_type, ts.ScalarType):
raise ValueError(f"Invalid argument to scalar expression {arg_expr}.")
except KeyError:
# this is the case of non-variable argument, e.g. target type such as `float64`,
# used in a casting expression like `cast_(variable, float64)`
visit_expr = False

if visit_expr:
# we visit the argument expression and obtain the access node to
# a scalar data container, which will be connected to the tasklet
arg_node, arg_type = sdfg_builder.visit(
arg_expr,
sdfg=sdfg,
head_state=state,
reduce_identity=reduce_identity,
)[0]
if not (
isinstance(arg_type, ts.ScalarType)
and isinstance(arg_node.desc(sdfg), dace.data.Scalar)
):
raise ValueError(f"Invalid argument to scalar expression {arg_expr}.")
param = f"__in_{arg_node.data}"
args.append(arg_node)
connectors.append(param)
scalar_expr_args.append(gtir.SymRef(id=param))
else:
assert isinstance(arg_expr, gtir.SymRef)
edopao marked this conversation as resolved.
Show resolved Hide resolved
scalar_expr_args.append(arg_expr)

# we visit the scalar expression replacing the input arguments with the corresponding data connectors
scalar_node = gtir.FunCall(fun=node.fun, args=scalar_expr_args)
python_code = gtir_python_codegen.get_source(scalar_node)
tasklet_node = sdfg_builder.add_tasklet(
name="scalar_expr",
state=state,
inputs=set(connectors),
outputs={"__out"},
code=f"__out = {python_code}",
)
# create edges for the input data connectors
for arg_node, conn in zip(args, connectors, strict=True):
state.add_edge(
arg_node,
None,
tasklet_node,
conn,
dace.Memlet(data=arg_node.data, subset="0"),
)
# finally, create temporary for the result value
temp_name, _ = sdfg.add_scalar(
sdfg.temp_data_name(),
dace_fieldview_util.as_dace_type(node.type),
find_new_name=True,
transient=True,
)
temp_node = state.add_access(temp_name)
state.add_edge(
tasklet_node,
"__out",
temp_node,
None,
dace.Memlet(data=temp_name, subset="0"),
)

return [(temp_node, node.type)]


def translate_symbol_ref(
node: gtir.Node,
sdfg: dace.SDFG,
Expand All @@ -379,20 +466,24 @@ def translate_symbol_ref(
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, gtir.SymRef)

sym_value = str(node.id)
sym_type = sdfg_builder.get_symbol_type(sym_value)
symbol_name = str(node.id)
# we retrieve the type of the symbol in the GT4Py prgram
gt_symbol_type = sdfg_builder.get_symbol_type(symbol_name)

# Create new access node in current state. It is possible that multiple
# access nodes are created in one state for the same data container.
# We rely on the dace simplify pass to remove duplicated access nodes.
if isinstance(sym_type, ts.FieldType):
sym_node = state.add_access(sym_value)
if isinstance(gt_symbol_type, ts.FieldType):
sym_node = state.add_access(symbol_name)
elif symbol_name in sdfg.arrays:
# access the existing scalar container
sym_node = state.add_access(symbol_name)
else:
sym_node = _get_symbolic_value(
sdfg, state, sdfg_builder, sym_value, sym_type, temp_name=sym_value
sdfg, state, sdfg_builder, symbol_name, gt_symbol_type, temp_name=f"__{symbol_name}"
)

return [(sym_node, sym_type)]
return [(sym_node, gt_symbol_type)]


if TYPE_CHECKING:
Expand All @@ -401,5 +492,6 @@ def translate_symbol_ref(
translate_as_field_op,
translate_cond,
translate_literal,
translate_scalar_expr,
translate_symbol_ref,
]
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ class SDFGBuilder(DataflowBuilder, Protocol):

@abc.abstractmethod
def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType:
"""Retrieve the GT4Py type of a symbol used in the program."""
pass

@abc.abstractmethod
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
"""Visit a node of the GT4Py IR."""
pass


Expand Down Expand Up @@ -366,6 +368,10 @@ def visit_FunCall(
reduce_identity=reduce_identity,
args=node_args,
)
elif isinstance(node.type, ts.ScalarType):
return gtir_builtin_translators.translate_scalar_expr(
node, sdfg, head_state, self, reduce_identity
)
else:
raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).")

Expand Down Expand Up @@ -395,21 +401,15 @@ def visit_Lambda(
lambda_symbols = self.global_symbols | {
pname: type_ for pname, (_, type_) in lambda_args_mapping.items()
}
# obtain the set of symbols that are used in the lambda node and all its child nodes
used_symbols = {str(sym.id) for sym in eve.walk_values(node).if_isinstance(gtir.SymRef)}

nsdfg = dace.SDFG(f"{sdfg.label}_nested")
nstate = nsdfg.add_state("lambda")

# add sdfg storage for the symbols that need to be passed as input parameters,
# that is only the symbols that are used in the context of the lambda node
# that are only the symbols used in the context of the lambda node
self._add_sdfg_params(
nsdfg,
[
gtir.Sym(id=p_name, type=p_type)
for p_name, p_type in lambda_symbols.items()
if p_name in used_symbols
],
[gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items()],
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
)

lambda_nodes = GTIRToSDFG(self.offset_provider, lambda_symbols.copy()).visit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def _make_cartesian_shift(
if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr):
# purely symbolic expression which can be interpreted at compile time
new_index = SymbolExpr(
dace.symbolic.SymExpr(index_expr.value) + offset_expr.value, index_expr.dtype
dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value,
index_expr.dtype,
)
else:
# the offset needs to be calculated by means of a tasklet (i.e. dynamic offset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_domain(
axis = named_range.args[0]
assert isinstance(axis, gtir.AxisLiteral)
bounds = [
dace.symbolic.SymExpr(gtir_python_codegen.get_source(arg))
dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg))
for arg in named_range.args[1:3]
]
dim = gtx_common.Dimension(axis.value, axis.kind)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ def test_gtir_update():
im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
)
stencil1 = im.as_fieldop(
im.lambda_("a")(im.plus(im.deref("a"), 1.0)),
im.lambda_("a")(im.plus(im.deref("a"), 0 - 1.0)),
domain,
)("x")
stencil2 = im.op_as_fieldop("plus", domain)("x", 1.0)
stencil2 = im.op_as_fieldop("plus", domain)("x", 0 - 1.0)

for i, stencil in enumerate([stencil1, stencil2]):
testee = gtir.Program(
Expand All @@ -200,7 +200,7 @@ def test_gtir_update():
sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

a = np.random.rand(N)
ref = a + 1.0
ref = a - 1.0

sdfg(a, **FSYMBOLS)
assert np.allclose(a, ref)
Expand Down Expand Up @@ -1382,6 +1382,50 @@ def test_gtir_let_lambda_with_cond():
assert np.allclose(b, a if s else a * 2)


def test_gtir_if_scalars():
domain = im.call("cartesian_domain")(
im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
)
testee = gtir.Program(
id="if_scalars",
function_definitions=[],
params=[
gtir.Sym(id="x", type=IFTYPE),
gtir.Sym(id="y_0", type=SIZE_TYPE),
gtir.Sym(id="y_1", type=SIZE_TYPE),
gtir.Sym(id="z", type=IFTYPE),
gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)),
gtir.Sym(id="size", type=SIZE_TYPE),
],
declarations=[],
body=[
gtir.SetAt(
expr=im.op_as_fieldop("plus", domain)(
"x",
im.cond(
"pred",
im.call("cast_")("y_0", "float64"),
im.call("cast_")("y_1", "float64"),
),
),
domain=domain,
target=gtir.SymRef(id="z"),
)
],
)

a = np.random.rand(N)
b = np.empty_like(a)
d1 = np.random.randint(0, 1000)
d2 = np.random.randint(0, 1000)

sdfg = dace_backend.build_sdfg_from_gtir(testee, {})

for s in [False, True]:
sdfg(a, y_0=d1, y_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS)
assert np.allclose(b, (a + d1 if s else a + d2))


def test_gtir_if_values():
domain = im.call("cartesian_domain")(
im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
Expand Down
Loading