Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RefineNestedAccess take indices into account when checking for missing free symbols #1317

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,27 @@ def free_symbols(self) -> Set[str]:
result |= self.dst_subset.free_symbols
return result

def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[int]) -> Set[str]:
"""
Returns set of free symbols used in this edges properties but only taking certain indices of the src and dst
subset into account

:param indices_src: The indices of the src subset to take into account
:type indices_src: List[int]
:param indices_dst: The indices of the dst subset to take into account
:type indices_dst: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
# Symbolic properties are in volume, and the two subsets
result = set()
result |= set(map(str, self.volume.free_symbols))
if self.src_subset:
result |= self.src_subset.get_free_symbols_by_indices(indices_src)
if self.dst_subset:
result |= self.dst_subset.get_free_symbols_by_indices(indices_dst)
return result

def get_stride(self, sdfg: 'dace.sdfg.SDFG', map: 'dace.sdfg.nodes.Map', dim: int = -1) -> 'dace.symbolic.SymExpr':
""" Returns the stride of the underlying memory when traversing a Map.

Expand Down
16 changes: 16 additions & 0 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def free_symbols(self) -> Set[str]:
result |= symbolic.symlist(d).keys()
return result

def get_free_symbols_by_indices(self, indices: List[int]) -> Set[str]:
"""
Get set of free symbols by only looking at the dimension given by the indices list

:param indices: The indices of the dimensions to look at
:type indices: List[int]
:return: The set of free symbols
:rtype: Set[str]
"""
result = set()
for i, dim in enumerate(self.ranges):
if i in indices:
for d in dim:
result |= symbolic.symlist(d).keys()
return result

def reorder(self, order):
""" Re-orders the dimensions in-place according to a permutation list.

Expand Down
4 changes: 2 additions & 2 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def _check_cand(candidates, outer_edges):
continue

# Check w.r.t. loops
if len(nstate.ranges) > 0:
if nstate is not None and len(nstate.ranges) > 0:
# Re-annotate loop ranges, in case someone changed them
# TODO: Move out of here!
for ns in nsdfg.sdfg.states():
Expand All @@ -1022,7 +1022,7 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.free_symbols - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
Expand Down
32 changes: 32 additions & 0 deletions tests/transformations/refine_nested_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def inner_sdfg(A: dace.int32[5, 5], B: dace.int32[5, 5], select: dace.bool[5, 5]
assert np.allclose(B, lower.T + lower - diag)


def test_free_sybmols_only_by_indices():
i = dace.symbol('i')
idx_a = dace.symbol('idx_a')
idx_b = dace.symbol('idx_b')
sdfg = dace.SDFG('refine_free_symbols_only_by_indices')
sdfg.add_array('A', [5], dace.int32)
sdfg.add_array('B', [5, 5], dace.int32)

@dace.program
def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
if A[i] > 0.5:
B[i, idx_a] = 1
else:
B[i, idx_b] = 0

state = sdfg.add_state()
A = state.add_access('A')
B = state.add_access('B')
map_entry, map_exit = state.add_map('map', dict(i='0:5'))
nsdfg = state.add_nested_sdfg(inner_sdfg.to_sdfg(simplify=False), sdfg, {'A'}, {'B'}, {'i': 'i'})
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='A', memlet=dace.Memlet.from_array('A', sdfg.arrays['A']))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='B', memlet=dace.Memlet.from_array('B', sdfg.arrays['B']))

num = sdfg.apply_transformations_repeated(RefineNestedAccess)
assert num == 1

assert len(state.in_edges(map_exit)) == 1
edge = state.in_edges(map_exit)[0]
assert edge.data.subset == dace.subsets.Range([(i, i, 1), (0, 4, 1)])


if __name__ == '__main__':
test_refine_dataflow()
test_refine_interstate()
test_free_sybmols_only_by_indices()
Loading