From 3360d4fed4b26a1f2c79d335882274f6745f6398 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 24 Sep 2024 14:41:19 -0400 Subject: [PATCH] compiler: fix handling of initvalue --- devito/finite_differences/finite_difference.py | 4 ++-- devito/finite_differences/tools.py | 8 +++----- devito/ir/equations/algorithms.py | 8 +++++++- devito/symbolics/manipulation.py | 2 ++ devito/symbolics/search.py | 6 +++--- tests/test_derivatives.py | 6 ++++++ 6 files changed, 23 insertions(+), 11 deletions(-) diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index d2b72ec36c..6ada6c40db 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -158,6 +158,8 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici # `coefficients` method (`taylor` or `symbolic`) if weights is None: weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0) + elif wdim is not None: + weights = [weights._subs(wdim, i) for i in range(len(indices))] # Enforce fixed precision FD coefficients to avoid variations in results weights = [sympify(w).evalf(_PRECISION) for w in weights] @@ -191,8 +193,6 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici deriv = DiffDerivative(expr*weights, {dim: indices.free_dim}) else: terms = [] - if wdim is not None: - weights = [weights._subs(wdim, i) for i in range(len(indices))] for i, c in zip(indices, weights): # The FD term term = expr._subs(dim, i) * c diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 8651ea15f7..b9afd8ca02 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -88,11 +88,9 @@ def generate_fd_shortcuts(dims, so, to=0): def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs): # Spearate dimension to always have cross derivatives return nested # derivatives. - # Reverse to match the syntax `u.dxdy = (u.dx).dy` with x the inner - # derivative - dims = as_tuple(dims)[::-1] - deriv_order = as_tuple(deriv_order)[::-1] - fd_order = as_tuple(fd_order)[::-1] + dims = as_tuple(dims) + deriv_order = as_tuple(deriv_order) + fd_order = as_tuple(fd_order) deriv = Derivative(expr, dims[0], deriv_order=deriv_order[0], fd_order=fd_order[0], side=side, **kwargs) for (d, do, fo) in zip(dims[1:], deriv_order[1:], fd_order[1:]): diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index 05cee3089d..85c4a67646 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -6,6 +6,7 @@ from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension, ConditionalDimension) +from devito.types.array import Array from devito.types.basic import AbstractFunction from devito.types.grid import MultiSubDimension @@ -135,8 +136,13 @@ def _lower_exprs(expressions, subs): if dimension_map: indices = [j.xreplace(dimension_map) for j in indices] - mapper[i] = f.indexed[indices] + # Handle Array + if isinstance(f, Array) and f.initvalue is not None: + initv = [_lower_exprs(i, subs) for i in f.initvalue] + # TODO: fix rebuild to avoid new name + f = f._rebuild(name='%si' % f.name, initvalue=initv) + mapper[i] = f.indexed[indices] # Add dimensions map to the mapper in case dimensions are used # as an expression, i.e. Eq(u, x, subdomain=xleft) mapper.update(dimension_map) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 1762c4250d..40f331f5b5 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from functools import singledispatch +import numpy as np from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort @@ -98,6 +99,7 @@ def _(expr, rule): return _uxreplace(expr, rule) +@_uxreplace_dispatch.register(np.ndarray) @_uxreplace_dispatch.register(tuple) @_uxreplace_dispatch.register(Tuple) @_uxreplace_dispatch.register(list) diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 9d2194abce..9d57cf8135 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -144,11 +144,11 @@ def retrieve_indexed(exprs, mode='all', deep=False): return search(exprs, q_indexed, mode, 'dfs', deep) -def retrieve_functions(exprs, mode='all'): +def retrieve_functions(exprs, mode='all', deep=False): """Shorthand to retrieve the DiscreteFunctions in `exprs`.""" - indexeds = search(exprs, q_indexed, mode, 'dfs') + indexeds = search(exprs, q_indexed, mode, 'dfs', deep) - functions = search(exprs, q_function, mode, 'dfs') + functions = search(exprs, q_function, mode, 'dfs', deep) functions.update({i.function for i in indexeds}) return functions diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 3472fcedb7..77c59ea390 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -731,6 +731,12 @@ def test_issue_2442(self): dfdxdy_split = f.dxc.dyc assert dfdxdy.evaluate == dfdxdy_split.evaluate + def test_cross_newnest(self): + grid = Grid((11, 11)) + f = Function(name="f", grid=grid, space_order=2) + + assert f.dxdy == f.dx.dy + class TestTwoStageEvaluation: