Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: revamp cross derivative shortcuts #2458

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 48 additions & 28 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ def __new__(cls, expr, *dims, **kwargs):
obj = Differentiable.__new__(cls, expr, *var_count)
obj._dims = tuple(OrderedDict.fromkeys(new_dims))

skip = kwargs.get('preprocessed', False) or obj.ndims == 1

obj._fd_order = fd_o if skip else DimensionTuple(*fd_o, getters=obj._dims)
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims)
obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims)
EdCaunt marked this conversation as resolved.
Show resolved Hide resolved
obj._deriv_order = DimensionTuple(*as_tuple(orders), getters=obj._dims)
obj._side = kwargs.get("side")
obj._transpose = kwargs.get("transpose", direct)
obj._method = kwargs.get("method", 'FD')
Expand Down Expand Up @@ -137,7 +135,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
fd_orders = kwargs.get('fd_order')
deriv_orders = kwargs.get('deriv_order')
if len(dims) == 1:
dims = tuple([dims[0]]*max(1, deriv_orders))
dims = tuple([dims[0]]*max(1, deriv_orders[0]))
variable_count = [sympy.Tuple(s, dims.count(s))
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count
Expand Down Expand Up @@ -222,25 +220,34 @@ def _process_weights(cls, **kwargs):

def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None):
side = side or self._side
method = method or self._method
EdCaunt marked this conversation as resolved.
Show resolved Hide resolved
weights = weights if weights is not None else self._weights

x0 = self._process_x0(self.dims, x0=x0)
_x0 = frozendict({**self.x0, **x0})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we actually need the underscore for all these variable names ?

it'd get way less verbose and then easier to read without the initial underscore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some yes but will try to cleanup a bit

if self.ndims == 1:
fd_order = fd_order or self._fd_order
method = method or self._method
weights = weights if weights is not None else self._weights
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method,
weights=weights)

# Cross derivative

_fd_order = dict(self.fd_order.getters)
try:
_fd_order.update(fd_order or {})
_fd_order = tuple(_fd_order.values())
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
except TypeError:
assert self.ndims == 1
_fd_order.update({self.dims[0]: fd_order or self.fd_order[0]})
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")
raise TypeError("fd_order incompatible with dimensions")

if isinstance(self.expr, Derivative):
# In case this was called on a perfect cross-derivative `u.dxdy`
# we need to propagate the call to the nested derivative
x0s = self._filter_dims(self.expr._filter_dims(_x0), neg=True)
expr = self.expr(x0=x0s, fd_order=self.expr._filter_dims(_fd_order),
side=side, method=method)
else:
expr = self.expr

_fd_order = self._filter_dims(_fd_order, as_tuple=True)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side)
return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=expr)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down Expand Up @@ -293,15 +300,32 @@ def _xreplace(self, subs):
except AttributeError:
return new, True

# Resolve nested derivatives
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)}
expr = self.expr.xreplace(dsubs)

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs), True
return self._rebuild(subs=subs, expr=expr), True

@cached_property
def _metadata(self):
ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__]
ret.append(self.expr.staggered or (None,))
return tuple(ret)

def _filter_dims(self, col, as_tuple=False, neg=False):
"""
Filter collection to only keep the Derivative's dimensions as keys.
"""
if neg:
filtered = {k: v for k, v in col.items() if k not in self.dims}
else:
filtered = {k: v for k, v in col.items() if k in self.dims}
if as_tuple:
return DimensionTuple(*filtered.values(), getters=self.dims)
else:
return filtered

@property
def dims(self):
return self._dims
Expand Down Expand Up @@ -422,13 +446,9 @@ def _eval_fd(self, expr, **kwargs):
"""
# Step 1: Evaluate non-derivative x0. We currently enforce a simple 2nd order
# interpolation to avoid very expensive finite differences on top of it
x0_interp = {}
x0_deriv = {}
for d, v in self.x0.items():
if d in self.dims:
x0_deriv[d] = v
elif not d.is_Time:
x0_interp[d] = v
x0_deriv = self._filter_dims(self.x0)
x0_interp = {d: v for d, v in self.x0.items()
if d not in x0_deriv and not d.is_Time}

if x0_interp and self.method == 'FD':
expr = interp_for_fd(expr, x0_interp, **kwargs)
Expand All @@ -446,7 +466,7 @@ def _eval_fd(self, expr, **kwargs):
# Step 3: Evaluate FD of the new expression
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
assert self.deriv_order[0] == 1
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif len(self.dims) > 1:
assert self.method == 'FD'
Expand All @@ -455,8 +475,8 @@ def _eval_fd(self, expr, **kwargs):
side=self.side)
else:
assert self.method == 'FD'
res = generic_derivative(expr, self.dims[0], as_tuple(self.fd_order)[0],
self.deriv_order, weights=self.weights,
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
self.deriv_order[0], weights=self.weights,
side=self.side, matvec=self.transpose,
x0=self.x0, expand=expand)

Expand Down
22 changes: 8 additions & 14 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,9 @@ def coefficients(self):
coefficients = {f.coefficients for f in self._functions}
# If there is multiple ones, we have to revert to the highest priority
# i.e we have to remove symbolic
key = lambda x: coeff_priority[x]
key = lambda x: coeff_priority.get(x, -1)
return sorted(coefficients, key=key, reverse=True)[0]

@cached_property
def _coeff_symbol(self, *args, **kwargs):
if self._uses_symbolic_coefficients:
return W
else:
raise ValueError("Couldn't find any symbolic coefficients")

def _eval_at(self, func):
if not func.is_Staggered:
# Cartesian grid, do no waste time
Expand Down Expand Up @@ -427,14 +420,14 @@ def has_free(self, *patterns):


def highest_priority(DiffOp):
prio = lambda x: getattr(x, '_fd_priority', 0)
# We want to get the object with highest priority
# We also need to make sure that the object with the largest
# set of dimensions is used when multiple ones with the same
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: this comment could be made clearer by replacing "ones" with "objects"

# priority appear
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


# Abstract symbol representing a symbolic coefficient
W = sympy.Function('W')


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand Down Expand Up @@ -1018,7 +1011,8 @@ def interp_for_fd(expr, x0, **kwargs):

@interp_for_fd.register(sympy.Derivative)
def _(expr, x0, **kwargs):
return expr.func(expr=interp_for_fd(expr.expr, x0, **kwargs))
x0_expr = {d: v for d, v in x0.items() if d not in expr.dims}
return expr.func(expr=interp_for_fd(expr.expr, x0_expr, **kwargs))


@interp_for_fd.register(sympy.Expr)
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def cross_derivative(expr, dims, fd_order, deriv_order, x0=None, side=None, **kw
Semantically, this is equivalent to

>>> (f*g).dxdy
Derivative(f(x, y)*g(x, y), x, y)
Derivative(Derivative(f(x, y)*g(x, y), x), y)

The only difference is that in the latter case derivatives remain unevaluated.
The expanded form is obtained via ``evaluate``
Expand Down
15 changes: 12 additions & 3 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,14 @@ def generate_fd_shortcuts(dims, so, to=0):
from devito.finite_differences.derivative import Derivative

def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order,
fd_order=fd_order, side=side, **kwargs)
# Separate dimensions to always have cross derivatives return nested
# derivatives. E.g `u.dxdy -> u.dx.dy`
dims = as_tuple(dims)
deriv_order = as_tuple(deriv_order)
fd_order = as_tuple(fd_order)
for (d, do, fo) in zip(dims, deriv_order, fd_order):
expr = Derivative(expr, d, deriv_order=do, fd_order=fo, side=side, **kwargs)
return expr

all_combs = dim_with_order(dims, orders)

Expand Down Expand Up @@ -225,7 +231,8 @@ def numeric_weights(function, deriv_order, indices, x0):
return finite_diff_weights(deriv_order, indices, x0)[-1][-1]


fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights}
fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights,
'symbolic': numeric_weights} # Backward compat for 'symbolic'
coeff_priority = {'taylor': 1, 'standard': 1}


Expand Down Expand Up @@ -318,6 +325,8 @@ def process_weights(weights, expr):
if weights is None:
return 0, None
elif isinstance(weights, Function):
if len(weights.dimensions) == 1:
return weights.shape[0], weights.dimensions[0]
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
assert len(wdim) == 1
wdim = wdim.pop()
Expand Down
8 changes: 7 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.grid import MultiSubDimension

Expand Down Expand Up @@ -135,8 +136,13 @@ def _lower_exprs(expressions, subs):
if dimension_map:
indices = [j.xreplace(dimension_map) for j in indices]

mapper[i] = f.indexed[indices]
# Handle Array
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FabioLuporini There is probably a cleaner way but don't have time to spend more on this rn

if isinstance(f, Array) and f.initvalue is not None:
initvalue = [_lower_exprs(i, subs) for i in f.initvalue]
# TODO: fix rebuild to avoid new name
f = f._rebuild(name='%si' % f.name, initvalue=initvalue)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure where's my previous comment gone honestly, I mean the one I thought I had written yesterday! Perhaps I forgot to append it in the end...

anyway, I was asking: instead of renaming, how about you just pass function=None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about you just pass function=Non

That's what I did first but somehow it led to some issues I'll think about it another time


mapper[i] = f.indexed[indices]
# Add dimensions map to the mapper in case dimensions are used
# as an expression, i.e. Eq(u, x, subdomain=xleft)
mapper.update(dimension_map)
Expand Down
2 changes: 2 additions & 0 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterable
from functools import singledispatch

import numpy as np
from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort
Expand Down Expand Up @@ -98,6 +99,7 @@ def _(expr, rule):
return _uxreplace(expr, rule)


@_uxreplace_dispatch.register(np.ndarray)
@_uxreplace_dispatch.register(tuple)
@_uxreplace_dispatch.register(Tuple)
@_uxreplace_dispatch.register(list)
Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ def retrieve_indexed(exprs, mode='all', deep=False):
return search(exprs, q_indexed, mode, 'dfs', deep)


def retrieve_functions(exprs, mode='all'):
def retrieve_functions(exprs, mode='all', deep=False):
"""Shorthand to retrieve the DiscreteFunctions in `exprs`."""
indexeds = search(exprs, q_indexed, mode, 'dfs')
indexeds = search(exprs, q_indexed, mode, 'dfs', deep)

functions = search(exprs, q_function, mode, 'dfs')
functions = search(exprs, q_function, mode, 'dfs', deep)
functions.update({i.function for i in indexeds})

return functions
Expand Down
3 changes: 2 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ def _apply_coeffs(cls, expr, coefficients):
for coeff in coefficients.coefficients:
derivs = [d for d in retrieve_derivatives(expr)
if coeff.dimension in d.dims and
coeff.deriv_order == d.deriv_order]
coeff.deriv_order == d.deriv_order.get(coeff.dimension, None)]
if not derivs:
continue
mapper.update({d: d._rebuild(weights=coeff.weights) for d in derivs})
if not mapper:
return expr

return expr.xreplace(mapper)

def _evaluate(self, **kwargs):
Expand Down
36 changes: 24 additions & 12 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):

@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
(Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'),
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dydz'], 3,
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dy', 'dz2'], 3,
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx2', 'dy'], 3,
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' +
' (x, 2)), z), y, z)')
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
])
def test_unevaluation(self, SymbolType, derivative, dim, expected):
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
Expand All @@ -111,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected):

@pytest.mark.parametrize('expr,expected', [
('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'),
('u.dxdy', 'Derivative(u, x, y)'),
('u.dxdy', 'Derivative(Derivative(u, x), y)'),
('u.laplace',
'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'),
('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'),
('((u.dx + u.dy).dx + u.dxdy).dx',
'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' +
' Derivative(u, x, y), x)'),
' Derivative(Derivative(u, x), y), x)'),
('(u**4).dx', 'Derivative(u**4, x)'),
('(u/4).dx', 'Derivative(u/4, x)'),
('((u.dx + v.dy).dx * v.dx).dy.dz',
Expand Down Expand Up @@ -403,6 +404,11 @@ def test_xderiv_x0(self):
- f.dx(x0=x+h_x/2).dy(x0=y+h_y/2).evaluate
assert simplify(expr) == 0

# Check x0 is correctly set
dfdxdx = f.dx(x0=x+h_x/2).dx(x0=x-h_x/2)
assert dict(dfdxdx.x0) == {x: x-h_x/2}
assert dict(dfdxdx.expr.x0) == {x: x+h_x/2}

def test_fd_new_side(self):
grid = Grid((10,))
u = Function(name="u", grid=grid, space_order=4)
Expand Down Expand Up @@ -659,9 +665,9 @@ def test_zero_spec(self):
drv1 = Derivative(f, (x, 2), (y, 0))
assert drv0.dims == (x,)
assert drv1.dims == (x, y)
assert drv0.fd_order == 2
assert drv0.fd_order == (2,)
assert drv1.fd_order == (2, 2)
assert drv0.deriv_order == 2
assert drv0.deriv_order == (2,)
assert drv1.deriv_order == (2, 0)

assert drv0.evaluate == drv1.evaluate
Expand Down Expand Up @@ -731,6 +737,12 @@ def test_issue_2442(self):
dfdxdy_split = f.dxc.dyc
assert dfdxdy.evaluate == dfdxdy_split.evaluate

def test_cross_newnest(self):
grid = Grid((11, 11))
f = Function(name="f", grid=grid, space_order=2)

assert f.dxdy == f.dx.dy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to test that chaining derivatives with various x0 leads to sensible consolidation too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class TestTwoStageEvaluation:

Expand Down Expand Up @@ -984,8 +996,8 @@ def test_laplacian_opt(self):
df = f.laplacian(order=2, shift=.5)
for (v, d) in zip(df.args, grid.dimensions):
assert v.dims[0] == d
assert v.fd_order == 2
assert v.deriv_order == 2
assert v.fd_order == (2,)
assert v.deriv_order == (2,)
assert d in v.x0


Expand Down
Loading
Loading