Skip to content

Commit

Permalink
why this fail
Browse files Browse the repository at this point in the history
  • Loading branch information
xtalax committed Mar 9, 2023
1 parent 82df76c commit bb17abc
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 47 deletions.
8 changes: 1 addition & 7 deletions src/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,15 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
kwargs::K
"""An internal argument for storing traits about the solving process."""
problem_type::PT
"""Whether the output has all saved states."""
dense_output::Bool
@add_kwonly function ODEProblem{iip}(f::AbstractODEFunction{iip},
u0, tspan, p = NullParameters(),
problem_type = StandardODEProblem();
dense_output = true,
kwargs...) where {iip}
_tspan = promote_tspan(tspan)
new{typeof(u0), typeof(_tspan),
isinplace(f), typeof(p), typeof(f),
typeof(kwargs),
typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type, dense_output)
typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type)
end

"""
Expand Down Expand Up @@ -487,6 +484,3 @@ function IncrementingODEProblem{iip}(f::IncrementingODEFunction, u0, tspan,
p = NullParameters(); kwargs...) where {iip}
ODEProblem(f, u0, tspan, p, IncrementingODEProblem{iip}(); kwargs...)
end

is_dense_output(prob::ODEProblem) = prob.dense_output
is_dense_output(prob) = true
6 changes: 2 additions & 4 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ function remake(prob::ODEProblem; f = missing,
end

if kwargs === missing
ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type;
dense_output = prob.dense_output, prob.kwargs...,
ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type;
dense_output = prob.dense_output, kwargs...)
ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; kwargs...)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/solutions/dae_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
HermiteInterpolation(t, u, du),
retcode = ReturnCode.Default,
destats = nothing,
sym_map = nothing,
sym_map = default_sym_map(prob),
kwargs...)
T = eltype(eltype(u))

Expand Down
2 changes: 1 addition & 1 deletion src/solutions/nonlinear_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function build_solution(prob::AbstractNonlinearProblem,
original = nothing,
left = nothing,
right = nothing,
sym_map = nothing,
sym_map = default_sym_map(prob),
dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]),
kwargs...)
T = eltype(eltype(u))
Expand Down
5 changes: 4 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense
if isnothing(dep_idxs)
dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing])
end
if isnothing(sym_map)
sym_map = default_sym_map(prob)
end
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(prob), typeof(alg), typeof(interp),
typeof(destats),
Expand Down Expand Up @@ -173,7 +176,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
k = nothing,
alg_choice = nothing,
interp = LinearInterpolation(t, u),
sym_map = nothing, dep_idxs = nothing,
sym_map = default_sym_map(prob), dep_idxs = nothing,
retcode = ReturnCode.Default, destats = nothing, kwargs...)
T = eltype(eltype(u))

Expand Down
2 changes: 1 addition & 1 deletion src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function build_solution(prob::AbstractOptimizationProblem,
alg, u, objective;
retcode = ReturnCode.Default,
original = nothing,
sym_map = nothing,
sym_map = default_sym_map(prob),
kwargs...)
T = eltype(eltype(u))
N = ndims(u)
Expand Down
2 changes: 1 addition & 1 deletion src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
interp = LinearInterpolation(t, u),
retcode = ReturnCode.Default,
alg_choice = nothing,
sym_map = nothing,
sym_map = default_sym_map(prob),
seed = UInt64(0), destats = nothing, kwargs...)
T = eltype(eltype(u))
N = length((size(prob.u0)..., length(u)))
Expand Down
60 changes: 29 additions & 31 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
if sym isa AbstractArray
return A[collect(sym)]
end

if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map)
i = get(A.sym_map, sym, nothing)
else
i = sym_to_index(sym, A)
end
i = state_sym_to_index(A, sym)
elseif all(issymbollike, sym)
if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) ||
!has_sys(A.prob.f) && has_paramsyms(A.prob.f) &&
Expand All @@ -87,6 +82,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
else
return [getindex.((A,), sym, i) for i in eachindex(A)]
end
elseif is_symbolic_expr(sym)
return convert_to_getindex(A, sym)
else
i = sym
end
Expand All @@ -104,6 +101,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
return observed(A, sym, :)
end
else
@show sym
observed(A, sym, :)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
Expand All @@ -118,13 +116,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
if sym isa AbstractArray
return A[collect(sym), args...]
end
if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map)
i = get(A.sym_map, sym, nothing)
else
i = sym_to_index(sym, A)
end
i = state_sym_to_index(A, sym)
elseif all(issymbollike, sym)
return reduce(vcat, map(s -> A[s, args...]', sym))
elseif is_symbolic_expr(sym)
return convert_to_getindex(A, sym, args...)
else
i = sym
end
Expand All @@ -134,6 +130,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
Symbol(sym) == getindepsym(A)
A.t[args...]
else
@show sym
observed(A, sym, args...)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
Expand All @@ -143,31 +140,33 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
end
end

function _get_dep_idxs(A::AbstractTimeseriesSolution)
function _get_dep_idxs(A::AbstractSciMLSolution)
SII = SymbolicIndexingInterface
if has_sys(A.prob.f) && has_observed(A.prob.f)
if !isnothing(A.sym_map)
is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) &&
!isnothing(getfield(A.prob.f.sys, :unknown_states))
if is_ODAE
sts = getfield(A.prob.f.sys, :unknown_states)
return map(x -> sym_to_index(x, A),
sts = unknown_states(A.prob.f.sys)
return map(x -> state_sym_to_index(A, x),
get_deps_of_observed(sts,
SII.observed(A.prob.f.sys)))
end
return map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys))
@show "not ODAE"
return map(x -> A.sym_map[safe_unwrap(x)], get_deps_of_observed(A.prob.f.sys))
end
end
return [nothing]
end

idxs_initialized(idxs) = isempty(idxs) || !isnothing(first(idxs))

function get_dep_idxs(A::AbstractTimeseriesSolution)
function get_dep_idxs(A::AbstractSciMLSolution)
if hasfield(typeof(A), :dep_idxs)
if idxs_initialized(A.dep_idxs[])
return A.dep_idxs[]
else
@show "recomputing dep_idxs"
idxs = _get_dep_idxs(A)
A.dep_idxs[] = idxs
return A.dep_idxs[]
Expand All @@ -178,27 +177,25 @@ function get_dep_idxs(A::AbstractTimeseriesSolution)
end

function observed(A::AbstractTimeseriesSolution, sym, i::Int)
dense = is_dense_output(A.prob)
idxs = dense ? [nothing] : get_dep_idxs(A)
if dense || !idxs_initialized(idxs)
idxs = get_dep_idxs(A)
if !idxs_initialized(idxs)
return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i])
end
getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i])
end

function observed(A::AbstractTimeseriesSolution, sym, is::AbstractArray{Int})
dense = is_dense_output(A.prob)
idxs = dense ? [nothing] : get_dep_idxs(A)
if dense || !idxs_initialized(idxs)
idxs = get_dep_idxs(A)
if !idxs_initialized(idxs)
return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i])
end
getobserved(A).((sym,), map(j -> A.u[j][idxs], is), (A.prob.p,), A.t[is])
end

function observed(A::AbstractTimeseriesSolution, sym, i::Colon)
dense = is_dense_output(A.prob)
idxs = dense ? [nothing] : get_dep_idxs(A)
if dense || !idxs_initialized(idxs)
idxs = get_dep_idxs(A)
@show idxs
if !idxs_initialized(idxs)
return getobserved(A).((sym,), A.u, (A.prob.p,), A.t)
end
getobserved(A).((sym,), map(j -> A.u[j][idxs], eachindex(A.t)), (A.prob.p,), A.t)
Expand All @@ -209,11 +206,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if sym isa AbstractArray
return A[collect(sym)]
end
if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map)
i = get(A.sym_map, sym, nothing)
else
i = sym_to_index(sym, A)
end
i = state_sym_to_index(A, sym)

elseif all(issymbollike, sym)
return reduce(vcat, map(s -> A[s]', sym))
else
Expand All @@ -235,7 +229,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
end

function observed(A::AbstractNoTimeSolution, sym)
getobserved(A)(sym, A.u, A.prob.p)
idxs = get_dep_idxs(A)
if !idxs_initialized(idxs)
return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i])
end
getobserved(A)(sym, A.u[idxs], A.prob.p)
end

function observed(A::AbstractOptimizationSolution, sym)
Expand Down
8 changes: 8 additions & 0 deletions src/symbolic_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ function getparamsyms(sol::AbstractOptimizationSolution)
end
end

function SymbolicIndexingInterface.state_sym_to_index(A::S, sym) where {S <: AbstractSciMLSolution}
if hasfield(S, :sym_map) && !isnothing(A.sym_map)
return get(A.sym_map, sym, nothing)
else
return sym_to_index(sym, A)
end
end

# Only for compatibility!
function getindepsym_defaultt(sol)
if has_indepsym(sol.prob.f)
Expand Down
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,12 @@ end

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B) = B

function default_sym_map(prob)
if has_sys(prob.f)
sts = safe_unwrap.(unknown_states(prob.f.sys))
return Dict(sts .=> eachindex(sts))
else
return nothing
end
end

0 comments on commit bb17abc

Please sign in to comment.