diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 8f27d336a8..af4f96ceae 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -152,20 +152,32 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation nlequation! = @closure (out, u, p) -> begin update_coefficients!(M, u, p, t) - #M * (u-u0)/dt - f(u,p,t) + # f(u,p,t) + M * (u0 - u)/dt tmp = isAD ? PreallocationTools.get_tmp(_tmp, u) : _tmp - @. tmp = (u - u0) / dt + @. tmp = (u0 - u) / dt mul!(_vec(out), M, _vec(tmp)) f(tmp, u, p, t) - out .-= tmp + out .+= tmp nothing end - nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD) + jac = if isnothing(f.jac) + f.jac + else + @closure (J, u, p) -> begin + # f(u,p,t) + M * (u0 - u)/dt + # df(u,p,t)/du - M/dt + f.jac(J, u, p, t) + J .-= M .* inv(dt) + nothing + end + end nlfunc = NonlinearFunction(nlequation!; - jac_prototype = f.jac_prototype) + jac_prototype = f.jac_prototype, + jac = jac) nlprob = NonlinearProblem(nlfunc, integrator.u, p) + nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD) nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) integrator.u .= nlsol.u @@ -227,10 +239,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation M * (u - u0) / dt - f(u, p, t) end + jac = if isnothing(f.jac) + f.jac + else + @closure (u, p) -> begin + return M * (u .- u0) ./ dt .- f.jac(u, p, t) + end + end + nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0) nlfunc = NonlinearFunction(nlequation_oop; - jac_prototype = f.jac_prototype) + jac_prototype = f.jac_prototype, + jac = jac) nlprob = NonlinearProblem(nlfunc, u0) nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) @@ -281,10 +302,20 @@ function _initialize_dae!(integrator, prob::DAEProblem, nothing end - nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD) + jac = if isnothing(f.jac) + f.jac + else + @closure (J, u, p) -> begin + f.jac(J, u, p, inv(dt), t) + nothing + end + end - nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype) + nlfunc = NonlinearFunction(nlequation!; + jac_prototype = f.jac_prototype, + jac = jac) nlprob = NonlinearProblem(nlfunc, u0, p) + nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD) nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) @@ -318,6 +349,16 @@ function _initialize_dae!(integrator, prob::DAEProblem, resid = f(integrator.du, u0, p, t) integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return + jac = if isnothing(f.jac) + f.jac + else + @closure (u, p) -> begin + return f.jac(u, p, inv(dt), t) + end + end + nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype, + jac = jac) + nlprob = NonlinearProblem(nlfunc, u0) nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0) nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)