Skip to content

Commit

Permalink
Merge pull request #1944 from oscardssmith/jacobian-for-initialization
Browse files Browse the repository at this point in the history
use jac for the ShampineCollocationInit nlsolve
  • Loading branch information
ChrisRackauckas authored Jun 16, 2023
2 parents 874db54 + d6c44da commit 73bcbf0
Showing 1 changed file with 49 additions and 8 deletions.
57 changes: 49 additions & 8 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,32 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation

nlequation! = @closure (out, u, p) -> begin
update_coefficients!(M, u, p, t)
#M * (u-u0)/dt - f(u,p,t)
# f(u,p,t) + M * (u0 - u)/dt
tmp = isAD ? PreallocationTools.get_tmp(_tmp, u) : _tmp
@. tmp = (u - u0) / dt
@. tmp = (u0 - u) / dt
mul!(_vec(out), M, _vec(tmp))
f(tmp, u, p, t)
out .-= tmp
out .+= tmp
nothing
end

nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
jac = if isnothing(f.jac)
f.jac
else
@closure (J, u, p) -> begin
# f(u,p,t) + M * (u0 - u)/dt
# df(u,p,t)/du - M/dt
f.jac(J, u, p, t)
J .-= M .* inv(dt)
nothing
end
end

nlfunc = NonlinearFunction(nlequation!;
jac_prototype = f.jac_prototype)
jac_prototype = f.jac_prototype,
jac = jac)
nlprob = NonlinearProblem(nlfunc, integrator.u, p)
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
reltol = integrator.opts.reltol)
integrator.u .= nlsol.u
Expand Down Expand Up @@ -227,10 +239,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
M * (u - u0) / dt - f(u, p, t)
end

jac = if isnothing(f.jac)
f.jac
else
@closure (u, p) -> begin
return M * (u .- u0) ./ dt .- f.jac(u, p, t)
end
end

nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)

nlfunc = NonlinearFunction(nlequation_oop;
jac_prototype = f.jac_prototype)
jac_prototype = f.jac_prototype,
jac = jac)
nlprob = NonlinearProblem(nlfunc, u0)
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
reltol = integrator.opts.reltol)
Expand Down Expand Up @@ -281,10 +302,20 @@ function _initialize_dae!(integrator, prob::DAEProblem,
nothing
end

nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
jac = if isnothing(f.jac)
f.jac
else
@closure (J, u, p) -> begin
f.jac(J, u, p, inv(dt), t)
nothing
end
end

nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
nlfunc = NonlinearFunction(nlequation!;
jac_prototype = f.jac_prototype,
jac = jac)
nlprob = NonlinearProblem(nlfunc, u0, p)
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
reltol = integrator.opts.reltol)

Expand Down Expand Up @@ -318,6 +349,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
resid = f(integrator.du, u0, p, t)
integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return

jac = if isnothing(f.jac)
f.jac
else
@closure (u, p) -> begin
return f.jac(u, p, inv(dt), t)
end
end
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype,
jac = jac)
nlprob = NonlinearProblem(nlfunc, u0)
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)

nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
Expand Down

0 comments on commit 73bcbf0

Please sign in to comment.