Skip to content

Commit

Permalink
compiler: fix handling of initvalue
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 25, 2024
1 parent d847ab3 commit 3360d4f
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 11 deletions.
4 changes: 2 additions & 2 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
# `coefficients` method (`taylor` or `symbolic`)
if weights is None:
weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0)
elif wdim is not None:
weights = [weights._subs(wdim, i) for i in range(len(indices))]

# Enforce fixed precision FD coefficients to avoid variations in results
weights = [sympify(w).evalf(_PRECISION) for w in weights]
Expand Down Expand Up @@ -191,8 +193,6 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
deriv = DiffDerivative(expr*weights, {dim: indices.free_dim})
else:
terms = []
if wdim is not None:
weights = [weights._subs(wdim, i) for i in range(len(indices))]
for i, c in zip(indices, weights):
# The FD term
term = expr._subs(dim, i) * c
Expand Down
8 changes: 3 additions & 5 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ def generate_fd_shortcuts(dims, so, to=0):
def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
# Spearate dimension to always have cross derivatives return nested
# derivatives.
# Reverse to match the syntax `u.dxdy = (u.dx).dy` with x the inner
# derivative
dims = as_tuple(dims)[::-1]
deriv_order = as_tuple(deriv_order)[::-1]
fd_order = as_tuple(fd_order)[::-1]
dims = as_tuple(dims)
deriv_order = as_tuple(deriv_order)
fd_order = as_tuple(fd_order)
deriv = Derivative(expr, dims[0], deriv_order=deriv_order[0],
fd_order=fd_order[0], side=side, **kwargs)
for (d, do, fo) in zip(dims[1:], deriv_order[1:], fd_order[1:]):
Expand Down
8 changes: 7 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.grid import MultiSubDimension

Expand Down Expand Up @@ -135,8 +136,13 @@ def _lower_exprs(expressions, subs):
if dimension_map:
indices = [j.xreplace(dimension_map) for j in indices]

mapper[i] = f.indexed[indices]
# Handle Array
if isinstance(f, Array) and f.initvalue is not None:
initv = [_lower_exprs(i, subs) for i in f.initvalue]
# TODO: fix rebuild to avoid new name
f = f._rebuild(name='%si' % f.name, initvalue=initv)

mapper[i] = f.indexed[indices]
# Add dimensions map to the mapper in case dimensions are used
# as an expression, i.e. Eq(u, x, subdomain=xleft)
mapper.update(dimension_map)
Expand Down
2 changes: 2 additions & 0 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterable
from functools import singledispatch

import numpy as np
from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort
Expand Down Expand Up @@ -98,6 +99,7 @@ def _(expr, rule):
return _uxreplace(expr, rule)


@_uxreplace_dispatch.register(np.ndarray)
@_uxreplace_dispatch.register(tuple)
@_uxreplace_dispatch.register(Tuple)
@_uxreplace_dispatch.register(list)
Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ def retrieve_indexed(exprs, mode='all', deep=False):
return search(exprs, q_indexed, mode, 'dfs', deep)


def retrieve_functions(exprs, mode='all'):
def retrieve_functions(exprs, mode='all', deep=False):
"""Shorthand to retrieve the DiscreteFunctions in `exprs`."""
indexeds = search(exprs, q_indexed, mode, 'dfs')
indexeds = search(exprs, q_indexed, mode, 'dfs', deep)

functions = search(exprs, q_function, mode, 'dfs')
functions = search(exprs, q_function, mode, 'dfs', deep)
functions.update({i.function for i in indexeds})

return functions
Expand Down
6 changes: 6 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,12 @@ def test_issue_2442(self):
dfdxdy_split = f.dxc.dyc
assert dfdxdy.evaluate == dfdxdy_split.evaluate

def test_cross_newnest(self):
grid = Grid((11, 11))
f = Function(name="f", grid=grid, space_order=2)

assert f.dxdy == f.dx.dy


class TestTwoStageEvaluation:

Expand Down

0 comments on commit 3360d4f

Please sign in to comment.