Skip to content

Commit

Permalink
tests: fix unevalution tests with new cross deriv
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 25, 2024
1 parent 3360d4f commit 263e354
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 28 deletions.
43 changes: 29 additions & 14 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,25 +220,37 @@ def _process_weights(cls, **kwargs):

def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None):
side = side or self._side
method = method or self._method
weights = weights if weights is not None else self._weights

x0 = self._process_x0(self.dims, x0=x0)
_x0 = frozendict({**self.x0, **x0})
if self.ndims == 1:
fd_order = fd_order or self._fd_order
method = method or self._method
weights = weights if weights is not None else self._weights
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method,
weights=weights)

# Cross derivative

_fd_order = dict(self.fd_order.getters)
try:
_fd_order.update(fd_order or {})
_fd_order = tuple(_fd_order.values())
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
except TypeError:
assert self.ndims == 1
_fd_order.update({self.dims[0]: fd_order or self.fd_order[0]})
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")
raise TypeError("fd_order incompatible with dimensions")

# In case this was called on a cross derivative we need to propagate
# the call to the nested derivative
if isinstance(self.expr, Derivative):
_fd_orders = {k: v for k, v in _fd_order.items() if k in self.expr.dims}
_x0s = {k: v for k, v in _x0.items() if k in self.expr.dims and
k not in self.dims}
new_expr = self.expr(x0=_x0s, fd_order=_fd_orders, side=side,
method=method, weights=weights)
else:
new_expr = self.expr

_fd_order = tuple(v for k, v in _fd_order.items() if k in self.dims)
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side)
return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=new_expr)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down Expand Up @@ -291,7 +303,10 @@ def _xreplace(self, subs):
except AttributeError:
return new, True

new_expr = self.expr.xreplace(subs)
# Resolve nested derivatives
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)}
new_expr = self.expr.xreplace(dsubs)

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs, expr=new_expr), True

Expand Down Expand Up @@ -445,7 +460,7 @@ def _eval_fd(self, expr, **kwargs):
# Step 3: Evaluate FD of the new expression
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
assert self.deriv_order[0] == 1
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif len(self.dims) > 1:
assert self.method == 'FD'
Expand Down
3 changes: 2 additions & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,8 @@ def interp_for_fd(expr, x0, **kwargs):

@interp_for_fd.register(sympy.Derivative)
def _(expr, x0, **kwargs):
return expr.func(expr=interp_for_fd(expr.expr, x0, **kwargs))
x0_expr = {d: v for d, v in x0.items() if d not in expr.dims}
return expr.func(expr=interp_for_fd(expr.expr, x0_expr, **kwargs))


@interp_for_fd.register(sympy.Expr)
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def cross_derivative(expr, dims, fd_order, deriv_order, x0=None, side=None, **kw
Semantically, this is equivalent to
>>> (f*g).dxdy
Derivative(f(x, y)*g(x, y), x, y)
Derivative(Derivative(f(x, y)*g(x, y), x), y)
The only difference is that in the latter case derivatives remain unevaluated.
The expanded form is obtained via ``evaluate``
Expand Down
25 changes: 13 additions & 12 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):

@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
(Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'),
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dydz'], 3,
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dy', 'dz2'], 3,
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx2', 'dy'], 3,
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' +
' (x, 2)), z), y, z)')
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
])
def test_unevaluation(self, SymbolType, derivative, dim, expected):
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
Expand All @@ -111,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected):

@pytest.mark.parametrize('expr,expected', [
('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'),
('u.dxdy', 'Derivative(u, x, y)'),
('u.dxdy', 'Derivative(Derivative(u, x), y)'),
('u.laplace',
'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'),
('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'),
('((u.dx + u.dy).dx + u.dxdy).dx',
'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' +
' Derivative(u, x, y), x)'),
' Derivative(Derivative(u, x), y), x)'),
('(u**4).dx', 'Derivative(u**4, x)'),
('(u/4).dx', 'Derivative(u/4, x)'),
('((u.dx + v.dy).dx * v.dx).dy.dz',
Expand Down Expand Up @@ -659,9 +660,9 @@ def test_zero_spec(self):
drv1 = Derivative(f, (x, 2), (y, 0))
assert drv0.dims == (x,)
assert drv1.dims == (x, y)
assert drv0.fd_order == 2
assert drv0.fd_order == (2,)
assert drv1.fd_order == (2, 2)
assert drv0.deriv_order == 2
assert drv0.deriv_order == (2,)
assert drv1.deriv_order == (2, 0)

assert drv0.evaluate == drv1.evaluate
Expand Down Expand Up @@ -990,8 +991,8 @@ def test_laplacian_opt(self):
df = f.laplacian(order=2, shift=.5)
for (v, d) in zip(df.args, grid.dimensions):
assert v.dims[0] == d
assert v.fd_order == 2
assert v.deriv_order == 2
assert v.fd_order == (2,)
assert v.deriv_order == (2,)
assert d in v.x0


Expand Down

0 comments on commit 263e354

Please sign in to comment.