Skip to content

Commit

Permalink
Use finite differences in differentiate2c for implicit functions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Oct 3, 2024
1 parent 0c90584 commit 8f4fc42
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
30 changes: 26 additions & 4 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,13 @@ def forwards_euler2c(diff_string, dt_var, vars, function_calls):
return f"{sp.ccode(x)} = {sp.ccode(solution, user_functions=custom_fcts)}"


def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
def differentiate2c(
expression,
dependent_var,
vars,
prev_expressions=None,
stepsize=1e-3,
):
"""Analytically differentiate supplied expression, return solution as C code.
Expression should be of the form "f(x)", where "x" is
Expand All @@ -595,11 +601,15 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
vars: set of all other variables used in expression, e.g. {"a", "b", "c"}
prev_expressions: time-ordered list of preceeding expressions
to evaluate & substitute, e.g. ["b = x + c", "a = 12*b"]
stepsize: in case an analytic expression is not possible, finite differences are used;
this argument sets the step size
Returns:
string containing analytic derivative of expression (including any substitutions
of variables from supplied prev_expressions) w.r.t. dependent_var as C code.
"""
if stepsize <= 0:
raise ValueError("arg `stepsize` must be > 0")
prev_expressions = prev_expressions or []
# every symbol (a.k.a variable) that SymPy
# is going to manipulate needs to be declared
Expand Down Expand Up @@ -643,15 +653,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
# differentiate w.r.t. x
diff = expr.diff(x).simplify()

# could be something generic like f'(x), in which case we use finite differences
if needs_finite_differences(diff):
diff = (
transform_expression(diff, discretize_derivative)
.subs({finite_difference_step_variable(x): stepsize})
.evalf()
)

# the codegen method does not like undefined function calls, so we extract
# them here
custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)}

# try to simplify expression in terms of existing variables
# ignore any exceptions here, since we already have a valid solution
# so if this further simplification step fails the error is not fatal
try:
# if expression is equal to one of the supplied vars, replace with this var
# can do a simple string comparison here since a var cannot be further simplified
diff_as_string = sp.ccode(diff)
diff_as_string = sp.ccode(diff, user_functions=custom_fcts)
for v in sympy_vars:
if diff_as_string == sp.ccode(sympy_vars[v]):
if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts):
diff = sympy_vars[v]

# or if equal to rhs of one of the supplied equations, replace with lhs
Expand All @@ -672,4 +694,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
pass

# return result as C code in NEURON format
return sp.ccode(diff.evalf())
return sp.ccode(diff.evalf(), user_functions=custom_fcts)
33 changes: 33 additions & 0 deletions test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# SPDX-License-Identifier: Apache-2.0

from nmodl.ode import differentiate2c, integrate2c
import pytest

import sympy as sp

Expand Down Expand Up @@ -100,6 +101,38 @@ def test_differentiate2c():
"g",
)

result = differentiate2c(
"-f(x)",
"x",
{},
)
# instead of comparing the expression as a string, we convert the string
# back to an expression and compare with an explicit function
size = 100
for index in range(size):
a, b = -5, 5
value = (b - a) * index / size + a
pytest.approx(
float(
sp.sympify(result)
.subs(sp.Function("f"), sp.sin)
.subs({"x": value})
.evalf()
)
) == float(
-sp.Derivative(sp.sin("x"))
.as_finite_difference(1e-3)
.subs({"x": value})
.evalf()
)
with pytest.raises(ValueError):
differentiate2c(
"-f(x)",
"x",
{},
stepsize=-1,
)


def test_integrate2c():

Expand Down

0 comments on commit 8f4fc42

Please sign in to comment.