Skip to content

Commit

Permalink
Fix used_symbols for the case of a symbol that is in SDFG.symbols but…
Browse files Browse the repository at this point in the history
… never used
  • Loading branch information
tbennun committed Sep 24, 2023
1 parent b2d6cd3 commit 92c7a14
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 18 deletions.
27 changes: 10 additions & 17 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __getitem__(self, key):
token = tokens.pop(0)
result = result.members[token]
return result

def __setitem__(self, key, val):
if isinstance(key, str) and '.' in key:
raise KeyError('NestedDict does not support setting nested keys')
Expand Down Expand Up @@ -1335,24 +1335,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
defined_syms = set()
free_syms = set()

# Exclude data descriptor names, constants, and shapes of global data descriptors
not_strictly_necessary_global_symbols = set()
for name, desc in self.arrays.items():
# Exclude data descriptor names and constants
for name in self.arrays.keys():
defined_syms.add(name)

if not all_symbols:
used_desc_symbols = desc.used_symbols(all_symbols)
not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols)
not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary))

defined_syms |= set(self.constants_prop.keys())

# Start with the set of SDFG free symbols
if all_symbols:
free_syms |= set(self.symbols.keys())
else:
free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols)

# Add free state symbols
used_before_assignment = set()

Expand All @@ -1378,6 +1366,11 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
# Remove symbols that were used before they were assigned
defined_syms -= used_before_assignment

# Add the set of SDFG symbol parameters
# If all_symbols is False, those symbols would only be added in the case of non-Python tasklets
if all_symbols:
free_syms |= set(self.symbols.keys())

# Subtract symbols defined in inter-state edges and constants
return free_syms - defined_syms

Expand Down Expand Up @@ -1498,7 +1491,7 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a
"""
arglist = arglist or self.arglist(scalars_only=not with_arrays)
return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()]

def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]:
""" Returns a list of arguments necessary to call this SDFG,
formatted as a list of Data-Centric Python definitions.
Expand Down Expand Up @@ -1528,7 +1521,7 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N
:param arglist: An optional cached argument list.
"""
return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist))

def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str:
""" Returns a Data-Centric Python signature of this SDFG, used when generating code.
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] =
del edge.data.assignments[sym]

# If symbols are never unknown any longer, remove from SDFG
fsyms = sdfg.free_symbols
fsyms = sdfg.used_symbols(all_symbols=False)
result = {k: v for k, v in result.items() if k not in fsyms}
for sym in result:
if sym in sdfg.symbols:
Expand Down

0 comments on commit 92c7a14

Please sign in to comment.