diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index e588feb51c..077be93304 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -62,7 +62,7 @@ def __getitem__(self, key): token = tokens.pop(0) result = result.members[token] return result - + def __setitem__(self, key, val): if isinstance(key, str) and '.' in key: raise KeyError('NestedDict does not support setting nested keys') @@ -1335,24 +1335,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: defined_syms = set() free_syms = set() - # Exclude data descriptor names, constants, and shapes of global data descriptors - not_strictly_necessary_global_symbols = set() - for name, desc in self.arrays.items(): + # Exclude data descriptor names and constants + for name in self.arrays.keys(): defined_syms.add(name) - if not all_symbols: - used_desc_symbols = desc.used_symbols(all_symbols) - not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols) - not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary)) - defined_syms |= set(self.constants_prop.keys()) - # Start with the set of SDFG free symbols - if all_symbols: - free_syms |= set(self.symbols.keys()) - else: - free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols) - # Add free state symbols used_before_assignment = set() @@ -1378,6 +1366,11 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Remove symbols that were used before they were assigned defined_syms -= used_before_assignment + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + # Subtract symbols defined in inter-state edges and constants return free_syms - defined_syms @@ -1498,7 +1491,7 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a """ arglist = arglist or self.arglist(scalars_only=not with_arrays) return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] - + def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]: """ Returns a list of arguments necessary to call this SDFG, formatted as a list of Data-Centric Python definitions. @@ -1528,7 +1521,7 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N :param arglist: An optional cached argument list. """ return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist)) - + def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: """ Returns a Data-Centric Python signature of this SDFG, used when generating code. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index dd2523c005..9cec6d11af 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -118,7 +118,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = del edge.data.assignments[sym] # If symbols are never unknown any longer, remove from SDFG - fsyms = sdfg.free_symbols + fsyms = sdfg.used_symbols(all_symbols=False) result = {k: v for k, v in result.items() if k not in fsyms} for sym in result: if sym in sdfg.symbols: