From 1a9696e2d0700818a9c09bde80fc5bb857cc152c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:08:19 -0800 Subject: [PATCH 01/13] Add special case for floor division in pystr_to_symbolic --- dace/symbolic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dace/symbolic.py b/dace/symbolic.py index 69d1d058c3..0c9b5c741d 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -971,7 +971,8 @@ class BitwiseOpConverter(ast.NodeTransformer): ast.BitXor: 'BitwiseXor', ast.Invert: 'BitwiseNot', ast.LShift: 'LeftShift', - ast.RShift: 'RightShift' + ast.RShift: 'RightShift', + ast.FloorDiv: 'int_floor', } def visit_UnaryOp(self, node): @@ -980,7 +981,7 @@ def visit_UnaryOp(self, node): ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[]) return ast.copy_location(new_node, node) - return node + return self.generic_visit(node) def visit_BinOp(self, node): if type(node.op) in BitwiseOpConverter._ast_to_sympy_functions: @@ -990,7 +991,7 @@ def visit_BinOp(self, node): args=[self.visit(value) for value in (node.left, node.right)], keywords=[]) return ast.copy_location(new_node, node) - return node + return self.generic_visit(node) @lru_cache(maxsize=16384) @@ -1054,7 +1055,7 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: # NOTE: If the expression contains bitwise operations, replace them with user-functions. # NOTE: Sympy does not support bitwise operations and converts them to boolean operations. - if isinstance(expr, str) and re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]', expr): + if isinstance(expr, str) and re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr): expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0])) # TODO: support SymExpr over-approximated expressions From 22f699d9de76fff35df0b3878c07463c6d345f72 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:08:50 -0800 Subject: [PATCH 02/13] Fix minor bug when trying to load library --- dace/codegen/compiled_sdfg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index a86e843ce0..90c28f592e 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -58,7 +58,10 @@ def is_loaded(self) -> bool: return True if not os.path.isfile(self._stub_filename): return False - self._stub = ctypes.CDLL(self._stub_filename) + try: + self._stub = ctypes.CDLL(self._stub_filename) + except OSError: + return False # Set return types of stub functions self._stub.load_library.restype = ctypes.c_void_p From 9592f9271063d0a78d3a1c282cf8aea2ccb5af82 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:12:56 -0800 Subject: [PATCH 03/13] Only add symbolic elements to shape inference equations --- dace/frontend/python/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 57539531d0..22099988c9 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -111,7 +111,8 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG, if repldict: sym_dim = sym_dim.subs(repldict) - equations.append(sym_dim - real_dim) + if symbolic.issymbolic(sym_dim - real_dim): + equations.append(sym_dim - real_dim) if len(symbols) == 0: return {} @@ -896,5 +897,4 @@ def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], simplify: Opti sdfg._regenerate_code = self.regenerate_code sdfg._recompile = self.recompile - return sdfg, cached From 2687aec18aed60c815384f0bedbaf803ce663bb9 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:13:26 -0800 Subject: [PATCH 04/13] Replace floor division by int_floor in Python frontend --- dace/frontend/python/replacements.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 03c6119ea4..72a94648e4 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -553,7 +553,12 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): @oprepo.replaces('elementwise') @oprepo.replaces('dace.elementwise') -def _elementwise(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, func: Union[StringLiteral, str], in_array: str, out_array=None): +def _elementwise(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + func: Union[StringLiteral, str], + in_array: str, + out_array=None): """ Apply a lambda function to each element in the input. """ @@ -1904,7 +1909,9 @@ def _scalar_sym_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, lef ">=": sp.GreaterThan, "<=": sp.LessThan, ">": sp.StrictGreaterThan, - "<": sp.StrictLessThan + "<": sp.StrictLessThan, + # Boolean ops + "//": symbolic.int_floor, } @@ -4612,10 +4619,13 @@ def _cupy_empty(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, shape: Shape, 'GtE': '__ge__' } + def _makeboolop(op: str, method: str): + @oprepo.replaces_operator('StringLiteral', op, otherclass='StringLiteral') def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLiteral, op2: StringLiteral): return getattr(op1, method)(op2) + for op, method in _boolop_to_method.items(): _makeboolop(op, method) From 4efe365e7a5157d77961c69c07d10c6b0089bde7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:17:06 -0800 Subject: [PATCH 05/13] Clear up interface of symstr and uses --- dace/codegen/common.py | 2 +- dace/codegen/cppunparse.py | 8 +++- dace/codegen/targets/cpp.py | 2 +- dace/codegen/targets/cuda.py | 6 +-- dace/codegen/targets/fpga.py | 13 +++--- dace/codegen/targets/intel_fpga.py | 11 +++-- dace/codegen/targets/mpi.py | 4 +- dace/codegen/targets/sve/infer.py | 1 - dace/data.py | 10 +++-- dace/libraries/blas/nodes/dot.py | 7 +-- dace/libraries/blas/nodes/ger.py | 4 +- dace/libraries/lapack/nodes/getri.py | 1 - dace/libraries/lapack/nodes/getrs.py | 1 - dace/sdfg/propagation.py | 22 ++++++--- dace/symbolic.py | 45 ++++++++++++------- dace/transformation/dataflow/map_for_loop.py | 2 +- dace/transformation/subgraph/expansion.py | 1 - dace/transformation/subgraph/helpers.py | 1 - .../subgraph/subgraph_fusion.py | 2 +- .../trivial_loop_elimination_test.py | 2 +- 20 files changed, 87 insertions(+), 58 deletions(-) diff --git a/dace/codegen/common.py b/dace/codegen/common.py index 3cf8b703a1..85c13ff22f 100644 --- a/dace/codegen/common.py +++ b/dace/codegen/common.py @@ -35,7 +35,7 @@ def find_outgoing_edges(node, dfg): @lru_cache(maxsize=16384) def _sym2cpp(s, arrayexprs): - return cppunparse.pyexpr2cpp(symbolic.symstr(s, arrayexprs)) + return cppunparse.pyexpr2cpp(symbolic.symstr(s, arrayexprs, cpp_mode=True)) def sym2cpp(s, arrayexprs: Optional[Set[str]] = None) -> Union[str, List[str]]: diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py index a0a14a3033..eae0ed229e 100644 --- a/dace/codegen/cppunparse.py +++ b/dace/codegen/cppunparse.py @@ -120,6 +120,7 @@ def interleave(inter, f, seq, **kwargs): class LocalScheme(object): + def is_defined(self, local_name, current_depth): raise NotImplementedError('Abstract class') @@ -131,6 +132,7 @@ def clear_scope(self, from_indentation): class CPPLocals(LocalScheme): + def __init__(self): # Maps local name to a 3-tuple of line number, scope (measured in indentation) and type self.locals = {} @@ -163,6 +165,7 @@ class CPPUnparser: """Methods in this class recursively traverse an AST and output C++ source code for the abstract syntax; original formatting is disregarded. """ + def __init__(self, tree, depth, @@ -1132,7 +1135,9 @@ def py2cpp(code, expr_semicolon=True, defined_symbols=None): return '\n'.join(py2cpp(stmt) for stmt in code) elif isinstance(code, sympy.Basic): from dace import symbolic - return cppunparse(ast.parse(symbolic.symstr(code)), expr_semicolon, defined_symbols=defined_symbols) + return cppunparse(ast.parse(symbolic.symstr(code, cpp_mode=True)), + expr_semicolon, + defined_symbols=defined_symbols) elif code.__class__.__name__ == 'function': try: code_str = inspect.getsource(code) @@ -1151,6 +1156,7 @@ def py2cpp(code, expr_semicolon=True, defined_symbols=None): else: raise NotImplementedError('Unsupported type for py2cpp') + @lru_cache(maxsize=16384) def pyexpr2cpp(expr): return py2cpp(expr, expr_semicolon=False) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index c34bad1f08..a0fa160b7b 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1256,7 +1256,7 @@ def visit_BinOp(self, node: ast.BinOp): **self.constants, 'dace': dace, 'math': math })) - evaluated = symbolic.symstr(symbolic.evaluate(unparsed, self.constants)) + evaluated = symbolic.symstr(symbolic.evaluate(unparsed, self.constants), cpp_mode=True) node.right = ast.parse(evaluated).body[0].value except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError): return self.generic_visit(node) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 539d2b49cb..ea729458c0 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -983,7 +983,7 @@ def _emit_copy(self, state_id, src_node, src_storage, dst_node, dst_storage, dst callsite_stream.write("}") if dims == 1 and not (src_strides[-1] != 1 or dst_strides[-1] != 1): - copysize = ' * '.join([cppunparse.pyexpr2cpp(symbolic.symstr(s)) for s in copy_shape]) + copysize = ' * '.join(_topy(copy_shape)) array_length = copysize copysize += ' * sizeof(%s)' % dtype.ctype @@ -2415,8 +2415,8 @@ def make_ptr_vector_cast(self, *args, **kwargs): def _topy(arr): """ Converts an array of symbolic variables (or one) to C++ strings. """ if not isinstance(arr, list): - return cppunparse.pyexpr2cpp(symbolic.symstr(arr)) - return [cppunparse.pyexpr2cpp(symbolic.symstr(d)) for d in arr] + return cppunparse.pyexpr2cpp(symbolic.symstr(arr, cpp_mode=True)) + return [cppunparse.pyexpr2cpp(symbolic.symstr(d, cpp_mode=True)) for d in arr] def _named_idx(idx): diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 61f2eeb21a..31cfc6e13f 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -1629,7 +1629,7 @@ def _emit_copy(self, sdfg, state_id, src_node, src_storage, dst_node, dst_storag else: raise TypeError("Memory copy type mismatch: {} vs {}".format(host_dtype, device_dtype)) - copysize = " * ".join([cppunparse.pyexpr2cpp(dace.symbolic.symstr(s)) for s in copy_shape]) + copysize = " * ".join([cppunparse.pyexpr2cpp(dace.symbolic.symstr(s, cpp_mode=True)) for s in copy_shape]) src_subset = memlet.src_subset or memlet.subset dst_subset = memlet.dst_subset or memlet.subset @@ -2198,13 +2198,14 @@ def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, callsite # ranges could have been defined in terms of floor/ceiling. Before printing the code # they are converted from a symbolic expression to a C++ compilable expression for it, r in reversed(list(zip(pipeline.params, pipeline.range))): - callsite_stream.write("if ({it} >= {end}) {{\n{it} = {begin};\n".format(it=it, - begin=dace.symbolic.symstr( - r[0]), - end=dace.symbolic.symstr(r[1]))) + callsite_stream.write("if ({it} >= {end}) {{\n{it} = {begin};\n".format( + it=it, + begin=dace.symbolic.symstr(r[0], cpp_mode=True), + end=dace.symbolic.symstr(r[1], cpp_mode=True))) for it, r in zip(pipeline.params, pipeline.range): callsite_stream.write("}} else {{\n{it} += {step};\n}}\n".format(it=it, - step=dace.symbolic.symstr(r[2]))) + step=dace.symbolic.symstr( + r[2], cpp_mode=True))) if len(cond) > 0: callsite_stream.write("}\n") callsite_stream.write("}\n}\n") diff --git a/dace/codegen/targets/intel_fpga.py b/dace/codegen/targets/intel_fpga.py index 4c104000e6..ddbf531791 100644 --- a/dace/codegen/targets/intel_fpga.py +++ b/dace/codegen/targets/intel_fpga.py @@ -1476,12 +1476,11 @@ def visit_BinOp(self, node): left_value = cppunparse.cppunparse(self.visit(node.left), expr_semicolon=False) try: - unparsed = symbolic.pystr_to_symbolic( - evalnode(node.right, { - **self.constants, - 'dace': dace, - })) - evaluated = symbolic.symstr(evaluate(unparsed, self.constants)) + unparsed = symbolic.pystr_to_symbolic(evalnode(node.right, { + **self.constants, + 'dace': dace, + })) + evaluated = symbolic.symstr(evaluate(unparsed, self.constants), cpp_mode=True) infered_type = infer_expr_type(evaluated, self.dtypes) right_value = evaluated diff --git a/dace/codegen/targets/mpi.py b/dace/codegen/targets/mpi.py index 044134e0ac..c2c85f897f 100644 --- a/dace/codegen/targets/mpi.py +++ b/dace/codegen/targets/mpi.py @@ -117,8 +117,8 @@ def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_st callsite_stream.write('{\n', sdfg, state_id, map_header) callsite_stream.write( '%s %s = %s + __dace_comm_rank * (%s);\n' % - (symtypes[var], var, cppunparse.pyexpr2cpp( - symbolic.symstr(begin)), cppunparse.pyexpr2cpp(symbolic.symstr(skip))), sdfg, state_id, map_header) + (symtypes[var], var, cppunparse.pyexpr2cpp(symbolic.symstr(begin, cpp_mode=True)), + cppunparse.pyexpr2cpp(symbolic.symstr(skip, cpp_mode=True))), sdfg, state_id, map_header) self._frame.allocate_arrays_in_scope(sdfg, map_header, function_stream, callsite_stream) diff --git a/dace/codegen/targets/sve/infer.py b/dace/codegen/targets/sve/infer.py index 49a84ca683..10556a9581 100644 --- a/dace/codegen/targets/sve/infer.py +++ b/dace/codegen/targets/sve/infer.py @@ -8,7 +8,6 @@ from dace import dtypes from dace.codegen import cppunparse from dace.symbolic import SymExpr -from dace.symbolic import symstr import sympy import sys diff --git a/dace/data.py b/dace/data.py index bb1cd08351..7443c39974 100644 --- a/dace/data.py +++ b/dace/data.py @@ -212,7 +212,8 @@ def validate(self): # `validate` function. def _validate(self): if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.shape): - raise TypeError('Shape must be a list or tuple of integer values ' 'or symbols') + raise TypeError('Shape must be a list or tuple of integer values ' + 'or symbols') return True def to_json(self): @@ -615,7 +616,8 @@ def validate(self): raise TypeError('Strides must be the same size as shape') if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.strides): - raise TypeError('Strides must be a list or tuple of integer ' 'values or symbols') + raise TypeError('Strides must be a list or tuple of integer ' + 'values or symbols') if len(self.offset) != len(self.shape): raise TypeError('Offset must be the same size as shape') @@ -848,7 +850,7 @@ def sizes(self): return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] def size_string(self): - return (" * ".join([cppunparse.pyexpr2cpp(symbolic.symstr(s)) for s in self.shape])) + return (" * ".join([cppunparse.pyexpr2cpp(symbolic.symstr(s, cpp_mode=True)) for s in self.shape])) def is_stream_array(self): return _prod(self.shape) != 1 @@ -926,6 +928,7 @@ class View(Array): In the Python frontend, ``numpy.reshape`` and ``numpy.ndarray.view`` both generate Views. """ + def validate(self): super().validate() @@ -949,6 +952,7 @@ class Reference(Array): In order to enable data-centric analysis and optimizations, avoid using References as much as possible. """ + def validate(self): super().validate() diff --git a/dace/libraries/blas/nodes/dot.py b/dace/libraries/blas/nodes/dot.py index 7a134cb7f4..40f5e0bae6 100644 --- a/dace/libraries/blas/nodes/dot.py +++ b/dace/libraries/blas/nodes/dot.py @@ -4,7 +4,6 @@ import dace.library import dace.properties import dace.sdfg.nodes -from dace.symbolic import symstr from dace.transformation.transformation import ExpandTransformation from dace.libraries.blas import blas_helpers from .. import environments @@ -544,9 +543,11 @@ def validate(self, sdfg, state): desc_res = sdfg.arrays[e.data.data] if desc_x.dtype != desc_y.dtype: - raise TypeError("Data types of input operands must be equal: " f"{desc_x.dtype}, {desc_y.dtype}") + raise TypeError("Data types of input operands must be equal: " + f"{desc_x.dtype}, {desc_y.dtype}") if desc_x.dtype.base_type != desc_res.dtype.base_type: - raise TypeError("Data types of input and output must be equal: " f"{desc_x.dtype}, {desc_res.dtype}") + raise TypeError("Data types of input and output must be equal: " + f"{desc_x.dtype}, {desc_res.dtype}") # Squeeze input memlets squeezed1 = copy.deepcopy(in_memlets[0].subset) diff --git a/dace/libraries/blas/nodes/ger.py b/dace/libraries/blas/nodes/ger.py index 21287dd29a..32170c9301 100644 --- a/dace/libraries/blas/nodes/ger.py +++ b/dace/libraries/blas/nodes/ger.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.symbolic import symstr from dace.properties import Property, SymbolicProperty from dace.transformation.transformation import ExpandTransformation from dace.frontend.common import op_repository as oprepo @@ -288,7 +287,8 @@ def validate(self, sdfg, state): desc_y = sdfg.arrays[memlet.data] if size_a is None or size_x is None: - raise ValueError("Expected at least two inputs to Ger " "(matrix A and vector x)") + raise ValueError("Expected at least two inputs to Ger " + "(matrix A and vector x)") if size_y is None: raise ValueError("Expected exactly one output from Ger (vector y).") diff --git a/dace/libraries/lapack/nodes/getri.py b/dace/libraries/lapack/nodes/getri.py index 7e22f82f6e..a2de84fb6b 100644 --- a/dace/libraries/lapack/nodes/getri.py +++ b/dace/libraries/lapack/nodes/getri.py @@ -3,7 +3,6 @@ import dace.library import dace.properties import dace.sdfg.nodes -from dace.symbolic import symstr from dace.transformation.transformation import ExpandTransformation from .. import environments from dace import data as dt, dtypes, memlet as mm, SDFG, SDFGState, symbolic diff --git a/dace/libraries/lapack/nodes/getrs.py b/dace/libraries/lapack/nodes/getrs.py index 6aba519118..3af1cc9e44 100644 --- a/dace/libraries/lapack/nodes/getrs.py +++ b/dace/libraries/lapack/nodes/getrs.py @@ -3,7 +3,6 @@ import dace.library import dace.properties import dace.sdfg.nodes -from dace.symbolic import symstr from dace.transformation.transformation import ExpandTransformation from .. import environments from dace import data as dt, dtypes, memlet as mm, SDFG, SDFGState, symbolic diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index a68fd98de2..7561696528 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -26,6 +26,7 @@ class MemletPattern(object): """ A pattern match on a memlet subset that can be used for propagation. """ + def can_be_applied(self, expressions, variable_context, node_range, orig_edges): raise NotImplementedError @@ -37,6 +38,7 @@ def propagate(self, array, expressions, node_range): class SeparableMemletPattern(object): """ Memlet pattern that can be applied to each of the dimensions separately. """ + def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims): raise NotImplementedError @@ -47,6 +49,7 @@ def propagate(self, array, dim_exprs, node_range): @registry.autoregister class SeparableMemlet(MemletPattern): """ Meta-memlet pattern that applies all separable memlet patterns. """ + def can_be_applied(self, expressions, variable_context, node_range, orig_edges): # Assuming correct dimensionality in each of the expressions data_dims = len(expressions[0]) @@ -111,6 +114,7 @@ class AffineSMemlet(SeparableMemletPattern): """ Separable memlet pattern that matches affine expressions, i.e., of the form `a * {index} + b`. """ + def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims): params = variable_context[-1] @@ -300,6 +304,7 @@ class ModuloSMemlet(SeparableMemletPattern): Acts as a meta-pattern: Finds the underlying pattern for `f(x)`. """ + def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims): # Pattern does not support unions of expressions if len(dim_exprs) > 1: return False @@ -351,6 +356,7 @@ class ConstantSMemlet(SeparableMemletPattern): """ Separable memlet pattern that matches constant (i.e., unrelated to current scope) expressions. """ + def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims): # Pattern does not support unions of expressions. TODO: Support if len(dim_exprs) > 1: return False @@ -394,6 +400,7 @@ def propagate(self, array, dim_exprs, node_range): class GenericSMemlet(SeparableMemletPattern): """ Separable memlet pattern that detects any expression, and propagates interval bounds. Used as a last resort. """ + def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims): dims = [] for dim in dim_exprs: @@ -446,8 +453,9 @@ def propagate(self, array, dim_exprs, node_range): if node_rs != 1: pos_lastindex = symbolic.pystr_to_symbolic( '%s + int_floor(%s - %s, %s) * %s' % - (symbolic.symstr(node_rb), symbolic.symstr(node_re), symbolic.symstr(node_rb), - symbolic.symstr(node_rs), symbolic.symstr(node_rs))) + (symbolic.symstr(node_rb, cpp_mode=False), symbolic.symstr(node_re, cpp_mode=False), + symbolic.symstr(node_rb, cpp_mode=False), symbolic.symstr( + node_rs, cpp_mode=False), symbolic.symstr(node_rs, cpp_mode=False))) neg_firstindex = pos_lastindex if isinstance(dim_exprs, list): @@ -515,6 +523,7 @@ def _subexpr(dexpr, repldict): class ConstantRangeMemlet(MemletPattern): """ Memlet pattern that matches arbitrary expressions with constant range. """ + def can_be_applied(self, expressions, variable_context, node_range, orig_edges): constant_range = True for dim in node_range: @@ -970,7 +979,8 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): if inside_memlet.wcr is not None: if (memlet.wcr is not None and memlet.wcr != inside_memlet.wcr): - warnings.warn('Memlet appears with more than one' ' type of write-conflict resolution.') + warnings.warn('Memlet appears with more than one' + ' type of write-conflict resolution.') memlet.wcr = inside_memlet.wcr if memlet.dynamic and memlet.volume == 0: @@ -1013,7 +1023,8 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): # union of the ranges to merge the subsets. if memlet.subset is not None: if memlet.subset.dims() != subset.dims(): - raise ValueError('Cannot merge subset ranges ' 'of unequal dimension!') + raise ValueError('Cannot merge subset ranges ' + 'of unequal dimension!') else: memlet.subset = subsets.union(memlet.subset, subset) if memlet.subset is None: @@ -1314,7 +1325,8 @@ def propagate_memlet(dfg_state, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined ' 'in SDFG.' % memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined ' + 'in SDFG.' % memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used diff --git a/dace/symbolic.py b/dace/symbolic.py index 0c9b5c741d..c3ae29e554 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1077,9 +1077,10 @@ class DaceSympyPrinter(sympy.printing.str.StrPrinter): """ Several notational corrections for integer math and C++ translation that sympy.printing.cxxcode does not provide. """ - def __init__(self, arrays, *args, **kwargs): + def __init__(self, arrays, cpp_mode=False, *args, **kwargs): super().__init__(*args, **kwargs) self.arrays = arrays or set() + self.cpp_mode = cpp_mode def _print_Float(self, expr): nf = sympy_numeric_fix(expr) @@ -1090,7 +1091,7 @@ def _print_Float(self, expr): def _print_Function(self, expr): if str(expr.func) in self.arrays: return f'{expr.func}[{expr.args[0]}]' - if str(expr.func) == 'int_floor': + if self.cpp_mode and str(expr.func) == 'int_floor': return '((%s) / (%s))' % (self._print(expr.args[0]), self._print(expr.args[1])) if str(expr.func) == 'AND': return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))' @@ -1111,14 +1112,18 @@ def _print_Not(self, expr): return '(not (%s))' % self._print(expr.args[0]) def _print_Infinity(self, expr): - return 'INFINITY' + if self.cpp_mode: + return 'INFINITY' + return super()._print_Infinity(expr) def _print_NegativeInfinity(self, expr): - return '-INFINITY' + if self.cpp_mode: + return '-INFINITY' + return super()._print_NegativeInfinity(expr) def _print_Symbol(self, expr): if expr.name == 'NoneSymbol': - return 'nullptr' + return 'nullptr' if self.cpp_mode else 'None' return super()._print_Symbol(expr) def _print_Pow(self, expr): @@ -1126,11 +1131,12 @@ def _print_Pow(self, expr): exponent = self._print(expr.args[1]) # Special case for square root - try: - if float(exponent) == 0.5: - return f'dace::math::sqrt({base})' - except ValueError: - pass + if self.cpp_mode: + try: + if float(exponent) == 0.5: + return f'dace::math::sqrt({base})' + except ValueError: + pass # Special case for integer powers try: @@ -1148,29 +1154,34 @@ def _print_Pow(self, expr): res = f'reciprocal({res})' return res except ValueError: - return "dace::math::pow({f}, {s})".format(f=self._print(expr.args[0]), s=self._print(expr.args[1])) + if self.cpp_mode: + return "dace::math::pow({f}, {s})".format(f=self._print(expr.args[0]), s=self._print(expr.args[1])) + else: + return f'({self._print(expr.args[0])}) ** ({self._print(expr.args[1])})' @lru_cache(maxsize=16384) -def symstr(sym, arrayexprs: Optional[Set[str]] = None) -> str: +def symstr(sym, arrayexprs: Optional[Set[str]] = None, cpp_mode=False) -> str: """ - Convert a symbolic expression to a C++ compilable expression. + Convert a symbolic expression to a compilable expression. :param sym: Symbolic expression to convert. :param arrayexprs: Set of names of arrays, used to convert SymPy user-functions back to array expressions. - :return: C++-compilable expression. + :param cpp_mode: If True, returns a C++-compilable expression. Otherwise, + returns a Python expression. + :return: Expression in string format depending on the value of ``cpp_mode``. """ if isinstance(sym, SymExpr): - return symstr(sym.expr, arrayexprs) + return symstr(sym.expr, arrayexprs, cpp_mode=cpp_mode) try: sym = sympy_numeric_fix(sym) sym = sympy_intdiv_fix(sym) sym = sympy_divide_fix(sym) - sstr = DaceSympyPrinter(arrayexprs).doprint(sym) + sstr = DaceSympyPrinter(arrayexprs, cpp_mode).doprint(sym) if isinstance(sym, symbol) or isinstance(sym, sympy.Symbol) or isinstance( sym, sympy.Number) or dtypes.isconstant(sym): @@ -1178,7 +1189,7 @@ def symstr(sym, arrayexprs: Optional[Set[str]] = None) -> str: else: return '(' + sstr + ')' except (AttributeError, TypeError, ValueError): - sstr = DaceSympyPrinter(arrayexprs).doprint(sym) + sstr = DaceSympyPrinter(arrayexprs, cpp_mode).doprint(sym) return '(' + sstr + ')' diff --git a/dace/transformation/dataflow/map_for_loop.py b/dace/transformation/dataflow/map_for_loop.py index 4294540536..b1d81e20a8 100644 --- a/dace/transformation/dataflow/map_for_loop.py +++ b/dace/transformation/dataflow/map_for_loop.py @@ -70,7 +70,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> Tuple[nodes.NestedSDFG, SDFGSta from dace.codegen.targets.cpp import cpp_array_expr def replace_param(param): - param = symbolic.symstr(param) + param = symbolic.symstr(param, cpp_mode=False) for p, pval in param_to_edge.items(): # TODO: Correct w.r.t. connector type param = param.replace(p, cpp_array_expr(nsdfg, pval.data)) diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index dee2603c62..db1e9b59ab 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -9,7 +9,6 @@ from dace.sdfg.graph import SubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property -from dace.symbolic import symstr from dace.sdfg.propagation import propagate_memlets_sdfg from dace.transformation.subgraph import helpers from collections import defaultdict diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index 630bd329bd..b2af49c879 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -5,7 +5,6 @@ from dace.memlet import Memlet from dace.sdfg import replace, SDFG, SDFGState from dace.properties import make_properties, Property -from dace.symbolic import symstr from dace.sdfg.propagation import propagate_memlets_sdfg from dace.sdfg.graph import SubgraphView diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index ef90af207f..fa66319ddf 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -10,7 +10,7 @@ from dace.memlet import Memlet from dace.transformation import transformation from dace.properties import EnumProperty, ListProperty, make_properties, Property -from dace.symbolic import symstr, overapproximate +from dace.symbolic import overapproximate from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlet, propagate_memlets_scope, _propagate_node from dace.transformation.subgraph import helpers from dace.transformation.dataflow import RedundantArray diff --git a/tests/transformations/trivial_loop_elimination_test.py b/tests/transformations/trivial_loop_elimination_test.py index 8c49bc308c..6f2769f921 100644 --- a/tests/transformations/trivial_loop_elimination_test.py +++ b/tests/transformations/trivial_loop_elimination_test.py @@ -2,7 +2,7 @@ from dace.sdfg.nodes import MapEntry import dace from dace.transformation.interstate import TrivialLoopElimination -from dace.symbolic import pystr_to_symbolic, symstr +from dace.symbolic import pystr_to_symbolic import unittest import numpy as np From 856cf52e9a8897f7fc1094e5795cab13c169e516 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 17 Feb 2023 19:24:30 -0800 Subject: [PATCH 06/13] Add test for floor division issue --- .../binop_replacements_test.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python_frontend/binop_replacements_test.py b/tests/python_frontend/binop_replacements_test.py index 9f08590dd2..b74baf4277 100644 --- a/tests/python_frontend/binop_replacements_test.py +++ b/tests/python_frontend/binop_replacements_test.py @@ -405,6 +405,25 @@ def test_mixed(): assert (C[0] == ref) +def test_sym_floordiv(): + M, N, K = 20, 20, 2 + + @dace.program + def tester(a: dace.float64[M, N, K]): + for flat in dace.map[0:M * N * K]: + i = flat // (N * K) + resid = flat % (N * K) + j = resid // K + k = resid % K + a[i, j, k] = i * 1000 + j * 100 + k + + a = np.random.rand(20, 20, 2) + ref = np.copy(a) + tester(a) + tester.f(ref) + assert np.allclose(a, ref) + + if __name__ == "__main__": test_array_array() test_array_array1() @@ -441,3 +460,4 @@ def test_mixed(): test_sym_sym() test_mixed() test_bool_bool() + test_sym_floordiv() From bc03ff92c200cc3d2a0e1b080629c4124df3fbd5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 18 Feb 2023 10:10:59 -0800 Subject: [PATCH 07/13] Fix string processing of symbolic expressions --- dace/subsets.py | 4 ++-- dace/symbolic.py | 19 ++++++++++--------- tests/sdfg/multiple_connector_test.py | 1 + 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 98f0f9f562..05918edf9b 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -461,8 +461,8 @@ def from_string(string): # If dimension has only 1 token, then it is an index (not a range), # treat as range of size 1 if len(uni_dim_tokens) < 2: - ranges.append( - (symbolic.pystr_to_symbolic(uni_dim_tokens[0]), symbolic.pystr_to_symbolic(uni_dim_tokens[0]), 1)) + value = symbolic.pystr_to_symbolic(uni_dim_tokens[0].strip()) + ranges.append((value, value, 1)) continue #return Range(ranges) # If dimension has more than 4 tokens, the range is invalid diff --git a/dace/symbolic.py b/dace/symbolic.py index c3ae29e554..c6d484fa18 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1048,15 +1048,16 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: # _clash also allows pi, beta, zeta and other common greek letters locals.update(_sympy_clash) - # Sympy processes "not/and/or" as direct evaluation. Replace with - # And/Or(x, y), Not(x) - if isinstance(expr, str) and re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b', expr): - expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0])) - - # NOTE: If the expression contains bitwise operations, replace them with user-functions. - # NOTE: Sympy does not support bitwise operations and converts them to boolean operations. - if isinstance(expr, str) and re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr): - expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0])) + if isinstance(expr, str): + # Sympy processes "not/and/or" as direct evaluation. Replace with + # And/Or(x, y), Not(x) + if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b', expr): + expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0])) + + # NOTE: If the expression contains bitwise operations, replace them with user-functions. + # NOTE: Sympy does not support bitwise operations and converts them to boolean operations. + if re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr): + expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0])) # TODO: support SymExpr over-approximated expressions try: diff --git a/tests/sdfg/multiple_connector_test.py b/tests/sdfg/multiple_connector_test.py index 50d7dad569..501be355a2 100644 --- a/tests/sdfg/multiple_connector_test.py +++ b/tests/sdfg/multiple_connector_test.py @@ -1,3 +1,4 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace from dace.sdfg import InvalidSDFGError From 674f600d5f34b45fb1158ba028010ebffd204046 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 22 Feb 2023 12:33:50 -0800 Subject: [PATCH 08/13] Fix memlet volume computation for nested indirect accesses --- dace/frontend/python/newast.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1ac9573212..5a4cc33018 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2988,16 +2988,23 @@ def _add_access( state = self.last_state + new_memlet = None + if has_indirection: + new_memlet = dace.Memlet.from_array(parent_name, parent_array) + new_memlet.volume = rng.num_elements() + else: + new_memlet = dace.Memlet.simple(parent_name, rng) + if access_type == 'r': if has_indirection: - self.inputs[var_name] = (state, dace.Memlet.from_array(parent_name, parent_array), inner_indices) + self.inputs[var_name] = (state, new_memlet, inner_indices) else: - self.inputs[var_name] = (state, dace.Memlet.simple(parent_name, rng), inner_indices) + self.inputs[var_name] = (state, new_memlet, inner_indices) else: if has_indirection: - self.outputs[var_name] = (state, dace.Memlet.from_array(parent_name, parent_array), inner_indices) + self.outputs[var_name] = (state, new_memlet, inner_indices) else: - self.outputs[var_name] = (state, dace.Memlet.simple(parent_name, rng), inner_indices) + self.outputs[var_name] = (state, new_memlet, inner_indices) self.variables[var_name] = var_name return (var_name, squeezed_rng) From c9e27a12b03936bcb026497752b078b8b4b8073d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Mar 2023 17:01:41 -0800 Subject: [PATCH 09/13] Fix review comments --- dace/data.py | 3 +-- dace/frontend/python/replacements.py | 2 +- dace/libraries/blas/nodes/dot.py | 6 ++---- dace/libraries/blas/nodes/ger.py | 3 +-- dace/sdfg/propagation.py | 9 +++------ 5 files changed, 8 insertions(+), 15 deletions(-) diff --git a/dace/data.py b/dace/data.py index 7443c39974..2fc5f334c6 100644 --- a/dace/data.py +++ b/dace/data.py @@ -616,8 +616,7 @@ def validate(self): raise TypeError('Strides must be the same size as shape') if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.strides): - raise TypeError('Strides must be a list or tuple of integer ' - 'values or symbols') + raise TypeError('Strides must be a list or tuple of integer values or symbols') if len(self.offset) != len(self.shape): raise TypeError('Offset must be the same size as shape') diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 72a94648e4..3e0ff554d5 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -1910,7 +1910,7 @@ def _scalar_sym_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, lef "<=": sp.LessThan, ">": sp.StrictGreaterThan, "<": sp.StrictLessThan, - # Boolean ops + # Binary ops "//": symbolic.int_floor, } diff --git a/dace/libraries/blas/nodes/dot.py b/dace/libraries/blas/nodes/dot.py index 40f5e0bae6..bbbd4fa0a1 100644 --- a/dace/libraries/blas/nodes/dot.py +++ b/dace/libraries/blas/nodes/dot.py @@ -543,11 +543,9 @@ def validate(self, sdfg, state): desc_res = sdfg.arrays[e.data.data] if desc_x.dtype != desc_y.dtype: - raise TypeError("Data types of input operands must be equal: " - f"{desc_x.dtype}, {desc_y.dtype}") + raise TypeError(f"Data types of input operands must be equal: {desc_x.dtype}, {desc_y.dtype}") if desc_x.dtype.base_type != desc_res.dtype.base_type: - raise TypeError("Data types of input and output must be equal: " - f"{desc_x.dtype}, {desc_res.dtype}") + raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}") # Squeeze input memlets squeezed1 = copy.deepcopy(in_memlets[0].subset) diff --git a/dace/libraries/blas/nodes/ger.py b/dace/libraries/blas/nodes/ger.py index 32170c9301..55ab4677f8 100644 --- a/dace/libraries/blas/nodes/ger.py +++ b/dace/libraries/blas/nodes/ger.py @@ -287,8 +287,7 @@ def validate(self, sdfg, state): desc_y = sdfg.arrays[memlet.data] if size_a is None or size_x is None: - raise ValueError("Expected at least two inputs to Ger " - "(matrix A and vector x)") + raise ValueError("Expected at least two inputs to Ger (matrix A and vector x)") if size_y is None: raise ValueError("Expected exactly one output from Ger (vector y).") diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 7561696528..89ba6928c7 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -979,8 +979,7 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): if inside_memlet.wcr is not None: if (memlet.wcr is not None and memlet.wcr != inside_memlet.wcr): - warnings.warn('Memlet appears with more than one' - ' type of write-conflict resolution.') + warnings.warn('Memlet appears with more than one type of write-conflict resolution.') memlet.wcr = inside_memlet.wcr if memlet.dynamic and memlet.volume == 0: @@ -1023,8 +1022,7 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): # union of the ranges to merge the subsets. if memlet.subset is not None: if memlet.subset.dims() != subset.dims(): - raise ValueError('Cannot merge subset ranges ' - 'of unequal dimension!') + raise ValueError('Cannot merge subset ranges of unequal dimension!') else: memlet.subset = subsets.union(memlet.subset, subset) if memlet.subset is None: @@ -1325,8 +1323,7 @@ def propagate_memlet(dfg_state, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined ' - 'in SDFG.' % memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used From 372a13944428dc5509879227ea6cd68220aaec9c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Mar 2023 17:11:55 -0800 Subject: [PATCH 10/13] Further fix nested volumes if symbolic --- dace/frontend/python/newast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 4794afc6f2..477bd1e18b 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2991,7 +2991,8 @@ def _add_access( new_memlet = None if has_indirection: new_memlet = dace.Memlet.from_array(parent_name, parent_array) - new_memlet.volume = rng.num_elements() + volume = rng.num_elements() + new_memlet.volume = volume if not symbolic.issymbolic(volume) else -1 else: new_memlet = dace.Memlet.simple(parent_name, rng) From 74205a2b912c86fbe6b811443b2f8ad5fe6515ea Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 15 Mar 2023 10:51:09 -0700 Subject: [PATCH 11/13] Add HIP choices to auto-optimize library nodes --- dace/transformation/auto/auto_optimize.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index fdf0fee2c7..f0fe22e181 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -319,10 +319,23 @@ def tile_wcrs(graph_or_subgraph: GraphViewType, validate_all: bool, prefer_parti def find_fast_library(device: dtypes.DeviceType) -> List[str]: + from dace.codegen.common import get_gpu_backend + # Returns the optimized library node implementations for the given target # device if device is dtypes.DeviceType.GPU: - return ['cuBLAS', 'cuSolverDn', 'GPUAuto', 'CUB', 'pure'] + try: + backend = get_gpu_backend() + except RuntimeError: + backend = 'none' + + if backend == 'cuda': + return ['cuBLAS', 'cuSolverDn', 'GPUAuto', 'CUB', 'pure'] + elif backend == 'hip': + return ['rocBLAS', 'GPUAuto', 'pure'] + else: + return ['GPUAuto', 'pure'] + elif device is dtypes.DeviceType.FPGA: return ['FPGA_PartialSums', 'FPGAPartialReduction', 'FPGA_Accumulate', 'FPGA1DSystolic', 'pure'] elif device is dtypes.DeviceType.CPU: From a18751f9ff7f3b24c91a65f6784fa5f44bebdf4a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 15 Mar 2023 10:51:40 -0700 Subject: [PATCH 12/13] Fix sympy dependency version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 464f299f5e..12562c2a85 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ }, include_package_data=True, install_requires=[ - 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', + 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy>=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' ] + cmake_requires, From ebbe0b84fad6485f0c8e3e0dd6043e04dc8169ff Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 15 Mar 2023 10:53:11 -0700 Subject: [PATCH 13/13] Support more values for transpositions in Python frontend Gemm call --- dace/libraries/blas/nodes/gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 8a700b9cb6..4a49397255 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -995,8 +995,8 @@ def __init__(self, name, location=None, transA=False, transB=False, alpha=1, bet location=location, inputs=({"_a", "_b", "_cin"} if beta != 0 and cin else {"_a", "_b"}), outputs={"_c"}) - self.transA = transA - self.transB = transB + self.transA = True if transA else False + self.transB = True if transB else False self.alpha = alpha self.beta = beta self.cin = cin