-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next][dace]: Skeleton of GTIR DaCe backend #1538
Merged
Merged
Changes from 48 commits
Commits
Show all changes
84 commits
Select commit
Hold shift + click to select a range
b36fcb3
Skeleton for ITIR translation
edopao 2020182
Minor edit
edopao 5c6b6ba
Use Python callstack as a context stack for the ITIR visitor
edopao 60e1c69
Format error
edopao 073a0a4
Refactor tasklet codegen
edopao 50be68f
Code refactoring
edopao 4e2dc15
Add domain to field operator
edopao ea9da35
Minor edit
edopao daf7827
Remove hard-coded field shape
edopao 9672b3b
Remove hard-coded target domain
edopao b6326b8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao 26f3790
Refactoring
edopao 1efffa7
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao f99fa84
Fix formatting
edopao d6e1088
More refactoring
edopao 9854497
Minor edit
edopao 37d83d7
Fix formatting
edopao 29986ef
Use callable to build taskgraph
edopao 390f3b4
Add draft of select operator
edopao de27419
Remove node mapping
edopao cd900f5
Remove node mapping (fix + test case)
edopao 326cbb5
Add test case for inlined mathematic builtins
edopao a10b614
Go full functional (remove SDFGState member var)
edopao aef4265
Minor edit
edopao 9e67dfe
Minor edit (1)
edopao 4b4109e
Fix state handling
edopao 495fd0a
Edit comments based on review
edopao 0085194
Add test case for nested select
edopao 41e2a44
Separate builtin translation from driver logic
edopao 7148c5f
Improve code comments
edopao 452399d
Avoid inheritance: pass dataflow builder as arg to builtin translator
edopao e404226
Codestyle review changes
edopao bb0dfac
Remove circular dependency for builtin translators
edopao 412cd5d
Fix formatting
edopao 651de5c
Minor edit
edopao dcf3eab
Add support to translate each builtin call to a tasklet node
edopao 7e6909e
Resolve dace warnings
edopao 2b07cc5
Remove bultin translator for domain expressions
edopao 2370fa6
Remove bultin translator for domain expressions (1)
edopao 8e801df
Refactor
edopao 812a6e5
Minor edit
edopao 1d0b50b
Extract ITIR visitor to separate class
edopao 97a1d22
Code refactoring
edopao a30cc7d
Fix formatting
edopao e9455e3
Changes in preparation for shift builtin
edopao 801704b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao c45c417
Add support for programs without computation (pure memlets)
edopao d67518a
Fix test
edopao c4385c1
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao ed16fd4
Import updates from branch dace-fieldview-shifts
edopao 9f7176f
Review comments
edopao 46febb0
Avoid tasklet-to-tasklet edge connections
edopao 949bad7
Add support for in-out field parameters
edopao 8890f95
Refactoring: import modules, not symbols
edopao 87b71a6
Minor edit
edopao 665a609
Remove internal package for builtin translators
edopao 82fdf64
Add wrapper function to build SDFG
edopao e4718b0
Merge pull request #4 from edopao/dace-fieldview-refactor_imports
edopao 6ccecf1
Code changes imported from branch dace-fieldview-shifts
edopao 3c71efa
Import changes from neighbors branch
edopao 2f75cfb
Add debuginfo for ir.Program and ir.Stmt nodes
edopao 085f307
Fix error in debuginfo
edopao f19960b
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao dc1434c
Fix error in debuginfo (1)
edopao a5b0f41
import changes from neighbors branch
edopao f7ac3d8
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao 9318011
Import changes from branch dace-fieldview-neighbors
edopao 11efdeb
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao d7312fa
Support field with start offset
edopao c4f2738
Test IR updated for literal operand
edopao 0fd0b65
Add test coverage to previous commit
edopao 38d2720
Refactor PrimitiveTranslator interface
edopao e855ef9
Fix formatting
edopao 4cff071
Fix for domain horzontal/vertical dims
edopao f642e85
Fix for type inference on single value expression
edopao fc9661c
Import changes from dace-fieldview-shifts
edopao e424d4e
Minor edit
edopao 66c5fcd
Address review comments
edopao d5abad4
Merge remote-tracking branch 'origin/main' into dace-fieldview
edopao 1df1bc3
Apply convention for map variables
edopao abf3918
Import changes from dace-fieldview-shifts
edopao 7f60cfe
Import changes from branch dace-fieldview-shifts
edopao b3131db
Avoid direct import of symbols from module
edopao 130c877
Address review comments
edopao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later |
31 changes: 31 additions & 0 deletions
31
src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_field_operator import ( | ||
GTIRBuiltinAsFieldOp as AsFieldOp, | ||
) | ||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_select import ( | ||
GTIRBuiltinSelect as Select, | ||
) | ||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_symbol_ref import ( | ||
GTIRBuiltinSymbolRef as SymbolRef, | ||
) | ||
|
||
|
||
# export short names of translation classes for GTIR builtin functions | ||
__all__ = [ | ||
"AsFieldOp", | ||
"Select", | ||
"SymbolRef", | ||
] |
155 changes: 155 additions & 0 deletions
155
...xt/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_field_operator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
|
||
from typing import Callable, TypeAlias | ||
|
||
import dace | ||
import dace.subsets as sbs | ||
|
||
from gt4py.next.common import Connectivity, Dimension | ||
from gt4py.next.iterator import ir as itir | ||
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm | ||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( | ||
GTIRBuiltinTranslator, | ||
) | ||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_tasklet import ( | ||
GTIRToTasklet, | ||
IteratorExpr, | ||
MemletExpr, | ||
SymbolExpr, | ||
TaskletExpr, | ||
) | ||
from gt4py.next.program_processors.runners.dace_fieldview.utility import get_domain, unique_name | ||
from gt4py.next.type_system import type_specifications as ts | ||
|
||
|
||
# Define type of variables used for field indexing | ||
_INDEX_DTYPE = dace.int64 | ||
|
||
|
||
class GTIRBuiltinAsFieldOp(GTIRBuiltinTranslator): | ||
"""Generates the dataflow subgraph for the `as_field_op` builtin function.""" | ||
|
||
TaskletConnector: TypeAlias = tuple[dace.nodes.Tasklet, str] | ||
|
||
stencil_expr: itir.Lambda | ||
stencil_args: list[Callable] | ||
field_domain: dict[Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]] | ||
field_type: ts.FieldType | ||
offset_provider: dict[str, Connectivity | Dimension] | ||
|
||
def __init__( | ||
self, | ||
sdfg: dace.SDFG, | ||
state: dace.SDFGState, | ||
node: itir.FunCall, | ||
stencil_args: list[Callable], | ||
offset_provider: dict[str, Connectivity | Dimension], | ||
): | ||
super().__init__(sdfg, state) | ||
self.offset_provider = offset_provider | ||
|
||
assert cpm.is_call_to(node.fun, "as_fieldop") | ||
assert len(node.fun.args) == 2 | ||
stencil_expr, domain_expr = node.fun.args | ||
# expect stencil (represented as a lambda function) as first argument | ||
assert isinstance(stencil_expr, itir.Lambda) | ||
# the domain of the field operator is passed as second argument | ||
assert isinstance(domain_expr, itir.FunCall) | ||
|
||
domain = get_domain(domain_expr) | ||
# define field domain with all dimensions in alphabetical order | ||
sorted_domain_dims = sorted(domain.keys(), key=lambda x: x.value) | ||
|
||
# add local storage to compute the field operator over the given domain | ||
# TODO: use type inference to determine the result type | ||
node_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) | ||
|
||
self.field_domain = domain | ||
self.field_type = ts.FieldType(sorted_domain_dims, node_type) | ||
self.stencil_expr = stencil_expr | ||
self.stencil_args = stencil_args | ||
|
||
def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: | ||
dimension_index_fmt = "i_{dim}" | ||
# first visit the list of arguments and build a symbol map | ||
stencil_args: list[IteratorExpr | MemletExpr] = [] | ||
for arg in self.stencil_args: | ||
arg_nodes = arg() | ||
assert len(arg_nodes) == 1 | ||
data_node, arg_type = arg_nodes[0] | ||
# require all argument nodes to be data access nodes (no symbols) | ||
assert isinstance(data_node, dace.nodes.AccessNode) | ||
|
||
if isinstance(arg_type, ts.ScalarType): | ||
scalar_arg = MemletExpr(data_node, sbs.Indices([0])) | ||
stencil_args.append(scalar_arg) | ||
else: | ||
assert isinstance(arg_type, ts.FieldType) | ||
indices: dict[str, MemletExpr | SymbolExpr | TaskletExpr] = { | ||
dim.value: SymbolExpr( | ||
dace.symbolic.SymExpr(dimension_index_fmt.format(dim=dim.value)), | ||
_INDEX_DTYPE, | ||
) | ||
for dim in self.field_domain.keys() | ||
} | ||
iterator_arg = IteratorExpr( | ||
data_node, | ||
[dim.value for dim in arg_type.dims], | ||
sbs.Indices([0] * len(arg_type.dims)), | ||
indices, | ||
) | ||
stencil_args.append(iterator_arg) | ||
|
||
# represent the field operator as a mapped tasklet graph, which will range over the field domain | ||
taskgen = GTIRToTasklet(self.sdfg, self.head_state, self.offset_provider) | ||
input_connections, output_expr = taskgen.visit(self.stencil_expr, args=stencil_args) | ||
assert isinstance(output_expr, TaskletExpr) | ||
|
||
# allocate local temporary storage for the result field | ||
field_shape = [ | ||
# diff between upper and lower bound | ||
self.field_domain[dim][1] - self.field_domain[dim][0] | ||
for dim in self.field_type.dims | ||
] | ||
field_node = self.add_local_storage(self.field_type, field_shape) | ||
|
||
# assume tasklet with single output | ||
output_index = ",".join( | ||
dimension_index_fmt.format(dim=dim.value) for dim in self.field_type.dims | ||
) | ||
output_memlet = dace.Memlet(data=field_node.data, subset=output_index) | ||
|
||
# create map range corresponding to the field operator domain | ||
map_ranges = { | ||
dimension_index_fmt.format(dim=dim.value): f"{lb}:{ub}" | ||
for dim, (lb, ub) in self.field_domain.items() | ||
} | ||
me, mx = self.head_state.add_map(unique_name("map"), map_ranges) | ||
|
||
for data_node, data_subset, lambda_node, lambda_connector in input_connections: | ||
memlet = dace.Memlet(data=data_node.data, subset=data_subset, volume=1) | ||
self.head_state.add_memlet_path( | ||
data_node, | ||
me, | ||
lambda_node, | ||
dst_conn=lambda_connector, | ||
memlet=memlet, | ||
) | ||
self.head_state.add_memlet_path( | ||
output_expr.node, mx, field_node, src_conn=output_expr.connector, memlet=output_memlet | ||
) | ||
|
||
return [(field_node, self.field_type)] |
125 changes: 125 additions & 0 deletions
125
...gt4py/next/program_processors/runners/dace_fieldview/gtir_builtins/gtir_builtin_select.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
|
||
from typing import Callable | ||
|
||
import dace | ||
|
||
from gt4py import eve | ||
from gt4py.next.iterator import ir as itir | ||
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm | ||
from gt4py.next.program_processors.runners.dace_fieldview.gtir_builtins.gtir_builtin_translator import ( | ||
GTIRBuiltinTranslator, | ||
) | ||
from gt4py.next.program_processors.runners.dace_fieldview.utility import get_symbolic_expr | ||
edopao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from gt4py.next.type_system import type_specifications as ts | ||
|
||
|
||
class GTIRBuiltinSelect(GTIRBuiltinTranslator): | ||
"""Generates the dataflow subgraph for the `select` builtin function.""" | ||
|
||
true_br_builder: Callable | ||
false_br_builder: Callable | ||
edopao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__( | ||
self, | ||
sdfg: dace.SDFG, | ||
state: dace.SDFGState, | ||
dataflow_builder: eve.NodeVisitor, | ||
node: itir.FunCall, | ||
): | ||
super().__init__(sdfg, state) | ||
|
||
assert cpm.is_call_to(node.fun, "select") | ||
assert len(node.fun.args) == 3 | ||
cond_expr, true_expr, false_expr = node.fun.args | ||
|
||
# expect condition as first argument | ||
cond = get_symbolic_expr(cond_expr) | ||
|
||
# use current head state to terminate the dataflow, and add a entry state | ||
# to connect the true/false branch states as follows: | ||
# | ||
# ------------ | ||
# === | select | === | ||
# || ------------ || | ||
# \/ \/ | ||
# ------------ ------------- | ||
# | true | | false | | ||
# ------------ ------------- | ||
# || || | ||
# || ------------ || | ||
# ==> | head | <== | ||
# ------------ | ||
# | ||
select_state = sdfg.add_state_before(state, state.label + "_select") | ||
sdfg.remove_edge(sdfg.out_edges(select_state)[0]) | ||
|
||
# expect true branch as second argument | ||
true_state = sdfg.add_state(state.label + "_true_branch") | ||
sdfg.add_edge(select_state, true_state, dace.InterstateEdge(condition=cond)) | ||
sdfg.add_edge(true_state, state, dace.InterstateEdge()) | ||
self.true_br_builder = dataflow_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) | ||
|
||
# and false branch as third argument | ||
false_state = sdfg.add_state(state.label + "_false_branch") | ||
sdfg.add_edge(select_state, false_state, dace.InterstateEdge(condition=(f"not {cond}"))) | ||
sdfg.add_edge(false_state, state, dace.InterstateEdge()) | ||
self.false_br_builder = dataflow_builder.visit( | ||
false_expr, sdfg=sdfg, head_state=false_state | ||
) | ||
|
||
def build(self) -> list[tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]]: | ||
# retrieve true/false states as predecessors of head state | ||
branch_states = tuple(edge.src for edge in self.sdfg.in_edges(self.head_state)) | ||
assert len(branch_states) == 2 | ||
if branch_states[0].label.endswith("_true_branch"): | ||
true_state, false_state = branch_states | ||
else: | ||
false_state, true_state = branch_states | ||
|
||
true_br_args = self.true_br_builder() | ||
false_br_args = self.false_br_builder() | ||
|
||
output_nodes = [] | ||
for true_br, false_br in zip(true_br_args, false_br_args, strict=True): | ||
true_br_node, true_br_type = true_br | ||
assert isinstance(true_br_node, dace.nodes.AccessNode) | ||
false_br_node, false_br_type = false_br | ||
assert isinstance(false_br_node, dace.nodes.AccessNode) | ||
assert true_br_type == false_br_type | ||
array_type = self.sdfg.arrays[true_br_node.data] | ||
access_node = self.add_local_storage(true_br_type, array_type.shape) | ||
output_nodes.append((access_node, true_br_type)) | ||
|
||
data_name = access_node.data | ||
true_br_output_node = true_state.add_access(data_name) | ||
true_state.add_nedge( | ||
true_br_node, | ||
true_br_output_node, | ||
dace.Memlet.from_array( | ||
true_br_output_node.data, true_br_output_node.desc(self.sdfg) | ||
), | ||
) | ||
|
||
false_br_output_node = false_state.add_access(data_name) | ||
false_state.add_nedge( | ||
false_br_node, | ||
false_br_output_node, | ||
dace.Memlet.from_array( | ||
false_br_output_node.data, false_br_output_node.desc(self.sdfg) | ||
), | ||
) | ||
return output_nodes |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to myself: does the data have to flow over a "local storage"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the final write to an external field, this internal node is redundant, and dace simplify pass will remove it. On the other hand, it is needed when we write back to a field also used as input (in-out field parameters): in this case, the dace simplify pass will correctly keep it. I have added a testcase
test_gtir_update
to cover this case.