diff --git a/pymbolic/mapper/subst_applier.py b/pymbolic/mapper/subst_applier.py new file mode 100644 index 00000000..f221575a --- /dev/null +++ b/pymbolic/mapper/subst_applier.py @@ -0,0 +1,42 @@ +__copyright__ = "Copyright (C) 2021 Thomas Gibson" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pymbolic.mapper + + +class SubstitutionApplier(pymbolic.mapper.IdentityMapper): + """todo. + """ + + def map_substitution(self, expr, current_substs): + new_substs = current_substs.copy() + new_substs.update( + {variable: self.rec(value, current_substs) + for variable, value in zip(expr.variables, expr.values)}) + return self.rec(expr.child, new_substs) + + def map_variable(self, expr, current_substs): + return current_substs.get(expr.name, expr) + + def __call__(self, expr): + current_substs = {} + return super().__call__(expr, current_substs) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 4055322a..285580fd 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1420,7 +1420,24 @@ def get_extra_properties(self): class Substitution(Expression): - """Work-alike of sympy's Subs.""" + """A (deferred) substitution applicable to a subexpression. + + See also sympy's ``Subs``. + + .. attribute:: child + + The sub-:class:`Expression` to which the substitution is to be applied. + + .. attribute:: variables + + A sequence of string identifiers of the variables to be replaced with + their corresponding entry in :attr:`values`. + + .. attribute:: values + + A sequence of sub-:class:`Expression` objects corresponding to each + string identifier in :attr:`variables`. + """ init_arg_names = ("child", "variables", "values") @@ -1429,6 +1446,9 @@ def __init__(self, child, variables, values): self.variables = variables self.values = values + if len(variables) != len(values): + raise ValueError("variables and values must have the same length") + def __getinitargs__(self): return (self.child, self.variables, self.values) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 428623d4..c8027327 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -678,6 +678,40 @@ def test_np_bool_handling(): assert evaluate(expr) is True +def test_subst_applier(): + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + from pymbolic.mapper.substitutor import substitute as subst_actual + + def subst_deferred(expr, **kwargs): + variables = [] + values = [] + for name, value in kwargs.items(): + variables.append(name) + values.append(value) + return prim.Substitution(expr, variables, values) + + from pymbolic.mapper.subst_applier import SubstitutionApplier + sapp = SubstitutionApplier() + + results = [] + for subst in [subst_actual, subst_deferred]: + expr = subst(x + y, x=5*y) + print(expr) + expr = subst(subst(expr**2, y=z) - subst(expr, y=x), x=y) + print(expr) + expr = sapp(expr) + print(expr) + + results.append(sapp(expr)) + print("--------") + + result_actual, result_deferred = results + assert result_actual == result_deferred + + if __name__ == "__main__": import sys if len(sys.argv) > 1: