From 7246062ec947b835f28f38b241f8b09aeaba4fb6 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 14 Nov 2021 18:58:19 -0600 Subject: [PATCH 1/2] implement an EqualityMapper with caching --- pymbolic/mapper/__init__.py | 7 +- pymbolic/mapper/equality.py | 258 ++++++++++++++++++++++++++++++++++++ pymbolic/polynomial.py | 7 - pymbolic/primitives.py | 33 +++-- test/test_pymbolic.py | 1 - 5 files changed, 277 insertions(+), 29 deletions(-) create mode 100644 pymbolic/mapper/equality.py diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index bac3b58c..8375a34b 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -120,8 +120,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs): """ raise UnsupportedExpressionError( - "{} cannot handle expressions of type {}".format( - type(self), type(expr))) + "'{}' cannot handle expressions of type '{}'".format( + type(self).__name__, type(expr).__name__)) def __call__(self, expr, *args, **kwargs): """Dispatch *expr* to its corresponding mapper method. Pass on @@ -155,7 +155,7 @@ def __call__(self, expr, *args, **kwargs): rec = __call__ def map_algebraic_leaf(self, expr, *args, **kwargs): - raise NotImplementedError + raise NotImplementedError(type(expr).__name__) def map_variable(self, expr, *args, **kwargs): return self.map_algebraic_leaf(expr, *args, **kwargs) @@ -486,6 +486,7 @@ def map_subscript(self, expr, *args, **kwargs): index = self.rec(expr.index, *args, **kwargs) if aggregate is expr.aggregate and index is expr.index: return expr + return type(expr)(aggregate, index) def map_lookup(self, expr, *args, **kwargs): diff --git a/pymbolic/mapper/equality.py b/pymbolic/mapper/equality.py new file mode 100644 index 00000000..5ce4dea7 --- /dev/null +++ b/pymbolic/mapper/equality.py @@ -0,0 +1,258 @@ +__copyright__ = "Copyright (C) 2021 Alexandru Fikl" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Any, Dict, Tuple + +from pymbolic.mapper import Mapper, UnsupportedExpressionError +from pymbolic.primitives import Expression + + +class EqualityMapper(Mapper): + __slots__ = ["_ids_to_result"] + + def __init__(self) -> None: + self._ids_to_result: Dict[Tuple[int, int], bool] = {} + + def rec(self, expr: Any, other: Any) -> bool: + key = (id(expr), id(other)) + if key in self._ids_to_result: + return self._ids_to_result[key] + + if expr is other: + result = True + elif expr.__class__ != other.__class__: + result = False + else: + try: + method = getattr(self, expr.mapper_method) + except AttributeError: + if isinstance(expr, Expression): + result = self.handle_unsupported_expression(expr, other) + else: + result = self.map_foreign(expr, other) + else: + result = method(expr, other) + + self._ids_to_result[key] = result + return result + + def __call__(self, expr: Any, other: Any) -> bool: + return self.rec(expr, other) + + # {{{ handle_unsupported_expression + + def handle_unsupported_expression(self, expr, other) -> bool: + eq = expr.make_equality_mapper() + if type(self) == type(eq): + raise UnsupportedExpressionError( + "'{}' cannot handle expressions of type '{}'".format( + type(self).__name__, type(expr).__name__)) + + # NOTE: this may look fishy, but we want to preserve the cache as we + # go through the expression tree, so that it does not do + # unnecessary checks when we change the mapper for some subclass + eq._ids_to_result = self._ids_to_result + + return eq(expr, other) + + # }}} + + # {{{ foreign + + def map_tuple(self, expr, other) -> bool: + return ( + len(expr) == len(other) + and all(self.rec(el, other_el) + for el, other_el in zip(expr, other))) + + def map_foreign(self, expr, other) -> bool: + import numpy as np + from pymbolic.primitives import VALID_CONSTANT_CLASSES + + if isinstance(expr, VALID_CONSTANT_CLASSES): + return expr == other + elif isinstance(expr, (np.ndarray, tuple)): + return self.map_tuple(expr, other) + else: + raise ValueError( + f"{type(self).__name__} encountered invalid foreign object: " + f"{expr!r}") + + # }}} + + # {{{ primitives + + # NOTE: `type(expr) == type(other)` is checked in `__call__`, so the + # checks below can assume that the two operands always have the same type + + # NOTE: as much as possible, these should try to put the "cheap" checks + # first so that the shortcircuiting removes the need to to extra work + + # NOTE: `all` is also shortcircuiting, so should be better to use a + # generator there to avoid extra work + + def map_nan(self, expr, other) -> bool: + return True + + def map_wildcard(self, expr, other) -> bool: + return True + + def map_function_symbol(self, expr, other) -> bool: + return True + + def map_variable(self, expr, other) -> bool: + return expr.name == other.name + + def map_subscript(self, expr, other) -> bool: + return ( + self.rec(expr.index, other.index) + and self.rec(expr.aggregate, other.aggregate)) + + def map_lookup(self, expr, other) -> bool: + return ( + expr.name == other.name + and self.rec(expr.aggregate, other.aggregate)) + + def map_call(self, expr, other) -> bool: + return ( + len(expr.parameters) == len(other.parameters) + and self.rec(expr.function, other.function) + and all(self.rec(p, other_p) + for p, other_p in zip(expr.parameters, other.parameters))) + + def map_call_with_kwargs(self, expr, other) -> bool: + return ( + len(expr.parameters) == len(other.parameters) + and len(expr.kw_parameters) == len(other.kw_parameters) + and self.rec(expr.function, other.function) + and all(self.rec(p, other_p) + for p, other_p in zip(expr.parameters, other.parameters)) + and all(k == other_k and self.rec(v, other_v) + for (k, v), (other_k, other_v) in zip( + sorted(expr.kw_parameters.items()), + sorted(other.kw_parameters.items())))) + + def map_sum(self, expr, other) -> bool: + return ( + len(expr.children) == len(other.children) + and all(self.rec(child, other_child) + for child, other_child in zip(expr.children, other.children)) + ) + + map_slice = map_sum + map_product = map_sum + map_min = map_sum + map_max = map_sum + + def map_bitwise_not(self, expr, other) -> bool: + return self.rec(expr.child, other.child) + + map_bitwise_and = map_sum + map_bitwise_or = map_sum + map_bitwise_xor = map_sum + map_logical_and = map_sum + map_logical_or = map_sum + map_logical_not = map_bitwise_not + + def map_quotient(self, expr, other) -> bool: + return ( + self.rec(expr.numerator, other.numerator) + and self.rec(expr.denominator, other.denominator) + ) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr, other) -> bool: + return ( + self.rec(expr.base, other.base) + and self.rec(expr.exponent, other.exponent) + ) + + def map_left_shift(self, expr, other) -> bool: + return ( + self.rec(expr.shift, other.shift) + and self.rec(expr.shiftee, other.shiftee)) + + map_right_shift = map_left_shift + + def map_comparison(self, expr, other) -> bool: + return ( + expr.operator == other.operator + and self.rec(expr.left, other.left) + and self.rec(expr.right, other.right)) + + def map_if(self, expr, other) -> bool: + return ( + self.rec(expr.condition, other.condition) + and self.rec(expr.then, other.then) + and self.rec(expr.else_, other.else_)) + + def map_if_positive(self, expr, other) -> bool: + return ( + self.rec(expr.criterion, other.criterion) + and self.rec(expr.then, other.then) + and self.rec(expr.else_, other.else_)) + + def map_common_subexpression(self, expr, other) -> bool: + return ( + expr.prefix == other.prefix + and expr.scope == other.scope + and self.rec(expr.child, other.child) + and all(k == other_k and v == other_v + for (k, v), (other_k, other_v) in zip( + expr.get_extra_properties(), + other.get_extra_properties()))) + + def map_substitution(self, expr, other) -> bool: + return ( + len(expr.variables) == len(other.variables) + and len(expr.values) == len(other.values) + and expr.variables == other.variables + and self.rec(expr.child, other.child) + and all(self.rec(v, other_v) + for v, other_v in zip(expr.values, other.values)) + ) + + def map_derivative(self, expr, other) -> bool: + return ( + len(expr.variables) == len(other.variables) + and self.rec(expr.child, other.child) + and all(self.rec(v, other_v) + for v, other_v in zip(expr.variables, other.variables))) + + def map_polynomial(self, expr, other) -> bool: + return ( + self.rec(expr.Base, other.Data) + and self.rec(expr.Data, other.Data)) + + # }}} + + # {{{ geometry_algebra.primitives + + def map_nabla_component(self, expr, other) -> bool: + return ( + expr.ambient_axis == other.ambient_axis + and expr.nabla_id == other.nabla_id + ) + + # }}} diff --git a/pymbolic/polynomial.py b/pymbolic/polynomial.py index 774f2fbd..f6bb6516 100644 --- a/pymbolic/polynomial.py +++ b/pymbolic/polynomial.py @@ -93,13 +93,6 @@ def traits(self): def __nonzero__(self): return len(self.Data) != 0 - def __eq__(self, other): - return isinstance(other, Polynomial) \ - and (self.Base == other.Base) \ - and (self.Data == other.Data) - def __ne__(self, other): - return not self.__eq__(other) - def __neg__(self): return Polynomial(self.Base, [(exp, -coeff) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index d267f83b..ae59a4df 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -202,7 +202,7 @@ class Expression: .. automethod:: make_stringifier .. automethod:: __eq__ - .. automethod:: is_equal + .. automethod:: make_equality_mapper .. automethod:: __hash__ .. automethod:: get_hash .. automethod:: __str__ @@ -511,18 +511,12 @@ def __repr__(self): # {{{ hash/equality interface def __eq__(self, other): - """Provides equality testing with quick positive and negative paths - based on :func:`id` and :meth:`__hash__`. + """Provides equality testing with quick positive and negative paths. Subclasses should generally not override this method, but instead - provide an implementation of :meth:`is_equal`. + provide an implementation of :meth:`make_equality_mapper`. """ - if self is other: - return True - elif hash(self) != hash(other): - return False - else: - return self.is_equal(other) + return self.make_equality_mapper()(self, other) def __ne__(self, other): return not self.__eq__(other) @@ -555,9 +549,18 @@ def __setstate__(self, state): # {{{ hash/equality backend + def make_equality_mapper(self): + from pymbolic.mapper.equality import EqualityMapper + return EqualityMapper() + def is_equal(self, other): - return (type(other) == type(self) - and self.__getinitargs__() == other.__getinitargs__()) + from warnings import warn + warn("'Expression.is_equal' is deprecated and will be removed in 2023. " + "To customize the equality check, subclass 'EqualityMapper' " + "and overwrite 'Expression.make_equality_mapper'", + DeprecationWarning, stacklevel=2) + + return self.make_equality_mapper()(self, other) def get_hash(self): return hash((type(self).__name__,) + self.__getinitargs__()) @@ -1038,12 +1041,6 @@ class Quotient(QuotientBase): .. attribute:: denominator """ - def is_equal(self, other): - from pymbolic.rational import Rational - return isinstance(other, (Rational, Quotient)) \ - and (self.numerator == other.numerator) \ - and (self.denominator == other.denominator) - mapper_method = intern("map_quotient") diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 774a4c82..f5f17443 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -824,7 +824,6 @@ def map_spatial_constant(self, expr): # {{{ test_equality_complexity -@pytest.mark.xfail def test_equality_complexity(): # NOTE: https://github.com/inducer/pymbolic/issues/73 from numpy.random import default_rng From 46c23921371f6034dc343a1dadfbd146d94787a9 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 28 Apr 2022 15:48:04 -0500 Subject: [PATCH 2/2] point ci to changed branches --- .github/workflows/ci.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 356bc548..a23ef694 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,6 +91,7 @@ jobs: downstream_tests: strategy: + fail-fast: false matrix: downstream_project: [loopy, grudge, pytential, pytato] name: Tests for downstream project ${{ matrix.downstream_project }} @@ -103,6 +104,11 @@ jobs: run: | curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 - test_downstream "$DOWNSTREAM_PROJECT" + + if [[ "$DOWNSTREAM_PROJECT" == "grudge" && "$GITHUB_HEAD_REF" == "equality-mapper" ]]; then + test_downstream "$DOWNSTREAM_PROJECT" + else + test_downstream "https://github.com/alexfikl/$DOWNSTREAM_PROJECT.git@equality-mapper" + fi # vim: sw=4