Skip to content

Commit

Permalink
Merge pull request #1996 from gaurav-arya/ag-wprot
Browse files Browse the repository at this point in the history
Support user-specified `W_prototype <: AbstractSciMLOperator`
  • Loading branch information
ChrisRackauckas authored Jul 31, 2023
2 parents b807375 + 93f46e8 commit 6a5bf95
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ PrecompileTools = "1"
Preferences = "1.3"
RecursiveArrayTools = "2.36"
Reexport = "0.2, 1.0"
SciMLBase = "1.90"
SciMLBase = "1.94"
SciMLNLSolve = "0.1"
SciMLOperators = "0.2.12, 0.3"
SimpleNonlinearSolve = "0.1.4"
Expand Down
3 changes: 3 additions & 0 deletions src/caches/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits},
fw3 = zero(rate_prototype)

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = similar(J, Complex{eltype(W1)})

du1 = zero(rate_prototype)
Expand Down
48 changes: 36 additions & 12 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,21 @@ end
SciMLBase.isinplace(::WOperator{IIP}, i) where {IIP} = IIP
Base.eltype(W::WOperator) = eltype(W.J)

set_gamma!(W::WOperator, gamma) = (W.gamma = gamma; W)
function SciMLOperators.update_coefficients!(W::WOperator, u, p, t)
update_coefficients!(W.J, u, p, t)
update_coefficients!(W.mass_matrix, u, p, t)
!isnothing(W.jacvec) && update_coefficients!(W.jacvec, u, p, t)
# In WOperator update_coefficients!, accept both missing u/p/t and missing dtgamma/transform and don't update them in that case.
# This helps support partial updating logic used with Newton solvers.
function SciMLOperators.update_coefficients!(W::WOperator,
u = nothing,
p = nothing,
t = nothing;
dtgamma = nothing,
transform = nothing)
if (u !== nothing) && (p !== nothing) && (t !== nothing)
update_coefficients!(W.J, u, p, t)
update_coefficients!(W.mass_matrix, u, p, t)
!isnothing(W.jacvec) && update_coefficients!(W.jacvec, u, p, t)
end
dtgamma !== nothing && (W.gamma = dtgamma)
transform !== nothing && (W.transform = transform)
W
end

Expand Down Expand Up @@ -672,10 +682,15 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach
end

# calculate W
if W isa WOperator
isnewton(nlsolver) || update_coefficients!(W, uprev, p, t) # we will call `update_coefficients!` in NLNewton
W.transform = W_transform
set_gamma!(W, dtgamma)
if W isa AbstractSciMLOperator && !(W isa Union{WOperator, StaticWOperator})
update_coefficients!(W, uprev, p, t; transform = W_transform, dtgamma)
elseif W isa WOperator
if isnewton(nlsolver)
# we will call `update_coefficients!` for u/p/t in NLNewton
update_coefficients!(W; transform = W_transform, dtgamma)
else
update_coefficients!(W, uprev, p, t; transform = W_transform, dtgamma)
end
if W.J !== nothing && !(W.J isa AbstractSciMLOperator)
islin, isode = islinearfunction(integrator)
islin ? (J = isode ? f.f : f.f1.f) :
Expand All @@ -687,7 +702,6 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach
islin, isode = islinearfunction(integrator)
islin ? (J = isode ? f.f : f.f1.f) :
(new_jac && (calc_J!(J, integrator, lcache, next_step)))
update_coefficients!(W, uprev, p, t)
new_W && !isdae && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
end
if isnewton(nlsolver)
Expand Down Expand Up @@ -724,7 +738,10 @@ end
islin, isode = islinearfunction(integrator)
!isdae && update_coefficients!(mass_matrix, uprev, p, t)

if islin
if cache.W isa AbstractSciMLOperator && !(cache.W isa Union{WOperator, StaticWOperator})
J = update_coefficients(cache.J, uprev, p, t)
W = update_coefficients(cache.W, uprev, p, t; dtgamma, transform = W_transform)
elseif islin
J = isode ? f.f : f.f1.f # unwrap the Jacobian accordingly
W = WOperator{false}(mass_matrix, dtgamma, J, uprev; transform = W_transform)
elseif DiffEqBase.has_jac(f)
Expand Down Expand Up @@ -831,7 +848,14 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
# TODO - if jvp given, make it SciMLOperators.FunctionOperator
# TODO - make mass matrix a SciMLOperator so it can be updated with time. Default to IdentityOperator
islin, isode = islinearfunction(f, alg)
if f.jac_prototype isa AbstractSciMLOperator
if isdefined(f, :W_prototype) && (f.W_prototype isa AbstractSciMLOperator)
# We use W_prototype when it is provided as a SciMLOperator, and in this case we require jac_prototype to be a SciMLOperator too.
if !(f.jac_prototype isa AbstractSciMLOperator)
error("SciMLOperator for W_prototype only supported when jac_prototype is a SciMLOperator, but got $(typeof(f.jac_prototype))")
end
W = f.W_prototype
J = f.jac_prototype
elseif f.jac_prototype isa AbstractSciMLOperator
W = WOperator{IIP}(f, u, dt)
J = W.J
elseif IIP && f.jac_prototype !== nothing && concrete_jac(alg) === nothing &&
Expand Down
24 changes: 17 additions & 7 deletions src/nlsolve/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
978-3-642-05221-7. Section IV.8.
[doi:10.1007/978-3-642-05221-7](https://doi.org/10.1007/978-3-642-05221-7).
"""
@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, false}, integrator)
@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, false}, integrator, γW)
@unpack uprev, t, p, dt, opts = integrator
@unpack z, tmp, γ, α, cache = nlsolver
@unpack tstep, W, invγdt = cache
Expand All @@ -63,8 +63,10 @@ Equations II, Springer Series in Computational Mathematics. ISBN
end

# update W
if W isa AbstractSciMLOperator
W = update_coefficients!(W, ustep, p, tstep)
if W isa Union{WOperator, StaticWOperator}
update_coefficients!(W, ustep, p, tstep)
elseif W isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by out-of-place Newton solve.")
end

dz = _reshape(W \ _vec(ztmp), axes(ztmp))
Expand All @@ -88,7 +90,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
ndz
end

@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, true}, integrator)
@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, true}, integrator, γW)
@unpack uprev, t, p, dt, opts = integrator
@unpack z, tmp, ztmp, γ, α, iter, cache = nlsolver
@unpack W_γdt, ustep, tstep, k, atmp, dz, W, new_W, invγdt, linsolve, weight = cache
Expand All @@ -103,8 +105,11 @@ end
b, ustep = _compute_rhs!(nlsolver, integrator, f, z)

# update W
if W isa AbstractSciMLOperator
if W isa Union{WOperator, StaticWOperator}
update_coefficients!(W, ustep, p, tstep)
elseif W isa AbstractSciMLOperator
# logic for generic AbstractSciMLOperator does not yet support partial state updates, so provide full state
update_coefficients!(W, ustep, p, tstep; dtgamma = γW, transform = true)
end

if integrator.opts.adaptive
Expand Down Expand Up @@ -155,7 +160,9 @@ end
ndz
end

@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, true, <:Array}, integrator)
@muladd function compute_step!(nlsolver::NLSolver{<:NLNewton, true, <:Array},
integrator,
γW)
@unpack uprev, t, p, dt, opts = integrator
@unpack z, tmp, ztmp, γ, α, iter, cache = nlsolver
@unpack W_γdt, ustep, tstep, k, atmp, dz, W, new_W, invγdt, linsolve, weight = cache
Expand All @@ -169,8 +176,11 @@ end
b, ustep = _compute_rhs!(nlsolver, integrator, f, z)

# update W
if W isa AbstractSciMLOperator
if W isa Union{WOperator, StaticWOperator}
update_coefficients!(W, ustep, p, tstep)
elseif W isa AbstractSciMLOperator
# logic for generic AbstractSciMLOperator does not yet support partial state updates, so provide full state
update_coefficients!(W, ustep, p, tstep; dtgamma = γW, transform = true)
end

if integrator.opts.adaptive
Expand Down
7 changes: 6 additions & 1 deletion src/nlsolve/nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ function nlsolve!(nlsolver::AbstractNLSolver, integrator::DiffEqBase.DEIntegrato

# compute next step and calculate norm of residuals
iter > 1 && (ndzprev = ndz)
ndz = compute_step!(nlsolver, integrator)
if isnewton(nlsolver)
# Newton solve requires γW in order to update W
ndz = compute_step!(nlsolver, integrator, γW)
else
ndz = compute_step!(nlsolver, integrator)
end
if !isfinite(ndz)
nlsolver.status = Divergence
nlsolver.nfails += 1
Expand Down
2 changes: 1 addition & 1 deletion test/interface/utility_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq: WOperator, set_gamma!, calc_W, calc_W!
using OrdinaryDiffEq: WOperator, calc_W, calc_W!
using OrdinaryDiffEq, LinearAlgebra, SparseArrays, Random, Test, LinearSolve

@testset "calc_W and calc_W!" begin
Expand Down
58 changes: 58 additions & 0 deletions test/interface/wprototype_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import ODEProblemLibrary: prob_ode_vanderpol_stiff
using SciMLOperators
using OrdinaryDiffEq
using LinearAlgebra
using LinearSolve
using Test

for prob in (prob_ode_vanderpol_stiff,)
# Ensure all solutions use the same linear solve for fair comparison.
# TODO: in future, ensure and test that polyalg chooses the best linear solve when unspecified.
for alg in (Rosenbrock23(linsolve = KrylovJL_GMRES()),
FBDF(linsolve = KrylovJL_GMRES()))
# Manually construct a custom W operator using the Jacobian
N = length(prob.u0)
J_op = MatrixOperator(rand(N, N); update_func! = prob.f.jac)
gamma_op = ScalarOperator(0.0;
update_func = (old_val, u, p, t; dtgamma) -> dtgamma,
accepted_kwargs = (:dtgamma,))
transform_op = ScalarOperator(0.0;
update_func = (old_op, u, p, t; dtgamma, transform) -> transform ?
inv(dtgamma) :
one(dtgamma),
accepted_kwargs = (:dtgamma, :transform))
W_op = -(I - gamma_op * J_op) * transform_op

# Make problem with custom MatrixOperator jac_prototype
f_J = ODEFunction(prob.f.f; jac_prototype = J_op)
prob_J = remake(prob; f = f_J)

# Test that the custom jacobian is used
integrator = init(prob_J, Rosenbrock23())
@test integrator.cache.J isa typeof(J_op)

# Make problem with custom SciMLOperator W_prototype
f_W = ODEFunction(prob.f.f; jac_prototype = J_op, W_prototype = W_op)
prob_W = remake(prob; f = f_W)

# Test that the custom W operator is used
integrator = init(prob_W, alg)
if hasproperty(integrator.cache, :W)
@test integrator.cache.W isa typeof(W_op)
elseif hasproperty(integrator.cache.nlsolver.cache, :W)
@test integrator.cache.nlsolver.cache.W isa typeof(W_op)
else
error("W-prototype test expected W in integrator.cache or integrator.cache.nlsolver.cache")
end

# Run solves
sol = solve(prob, alg)
sol_J = solve(prob_J, alg) # note: direct linsolve in this case is broken, see #1998
sol_W = solve(prob_W, alg)

@test all(isapprox.(sol_J.t, sol.t))
@test all(isapprox.(sol_J.u, sol.u))
@test all(isapprox.(sol_W.t, sol.t))
@test all(isapprox.(sol_W.u, sol.u))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ end
@time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl")
@time @safetestset "Enum Tests" include("interface/enums.jl")
@time @safetestset "Mass Matrix Tests" include("interface/mass_matrix_tests.jl")
@time @safetestset "W-Operator prototype tests" include("interface/wprototype_tests.jl")
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceIII" || GROUP == "Interface")
Expand Down

0 comments on commit 6a5bf95

Please sign in to comment.