Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Expand All @@ -33,10 +34,11 @@ FunctionWrappersWrappers = "0.1"
IteratorInterfaceExtensions = "^0.1, ^1"
Preferences = "1.3"
RecipesBase = "0.7.0, 0.8, 1.0"
RecursiveArrayTools = "2.14"
RecursiveArrayTools = "2.33"
RuntimeGeneratedFunctions = "0.5"
StaticArrays = "1"
StaticArraysCore = "1"
SymbolicIndexingInterface = "0.2"
Tables = "1"
julia = "1.6"

Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module SciMLBase

using ConstructionBase
using RecipesBase, RecursiveArrayTools, Tables
using SymbolicIndexingInterface
using DocStringExtensions
using LinearAlgebra
using Statistics
Expand Down
25 changes: 18 additions & 7 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,13 @@ function getobserved(integrator::DEIntegrator)
end
end

sym_to_index(sym, integrator::DEIntegrator) = sym_to_index(sym, getsyms(integrator))
function sym_to_index(sym, integrator::DEIntegrator)
if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym)
return state_sym_to_index(integrator.f.sys, sym)
else
return sym_to_index(sym, getsyms(integrator))
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator,
I::Union{Int, AbstractArray{Int},
Expand All @@ -432,13 +438,18 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
i = sym
end

indepsym = getindepsym(A)
paramsyms = getparamsyms(A)
if i === nothing
if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym
A.t
elseif issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms
A.p[findfirst(isequal(Symbol(sym)), paramsyms)]
if issymbollike(sym)
if has_sys(A.f) && is_indep_sym(A.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.f) && is_param_sym(A.f.sys, sym)
return A.p[param_sym_to_index(A.f.sys, sym)]
elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A)
return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
else
return observed(A, sym)
end
else
observed(A, sym)
end
Expand Down
1 change: 1 addition & 0 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3654,6 +3654,7 @@ has_tgrad(f::AbstractSciMLFunction) = __has_tgrad(f) && f.tgrad !== nothing
has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
has_syms(f::AbstractSciMLFunction) = __has_syms(f) && f.syms !== nothing
has_indepsym(f::AbstractSciMLFunction) = __has_indepsym(f) && f.indepsym !== nothing
has_paramsyms(f::AbstractSciMLFunction) = __has_paramsyms(f) && f.paramsyms !== nothing
Expand Down
40 changes: 31 additions & 9 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,31 @@ end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::Integer, continuity) where {deriv}
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
[sol.prob.f.syms[idxs]] : nothing
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
if has_sys(sol.prob.f)
DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed),
typeof(p)}(A.u, A.t, sol.prob.f.sys, observed, p)
else
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
[sol.prob.f.syms[idxs]] : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
sol.prob.f.syms[idxs] : nothing
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
if has_sys(sol.prob.f)
DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed),
typeof(p)}(A.u, A.t, sol.prob.f.sys, observed, p)
else
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
sol.prob.f.syms[idxs] : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
Expand All @@ -118,7 +128,12 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p)
if has_sys(sol.prob.f)
return DiffEqArray(interp_sol[idxs], t, [idxs],
independent_variables(sol.prob.f.sys), observed, p)
else
return DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p)
end
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand All @@ -127,8 +142,15 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, idxs,
getindepsym(sol), observed, p)
if has_sys(sol.prob.f)
return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t,
idxs,
independent_variables(sol.prob.f.sys), observed, p)
else
return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t,
idxs,
getindepsym(sol), observed, p)
end
end

function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
Expand Down
48 changes: 34 additions & 14 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@ function Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution)
end

# For augmenting system information to enable symbol based indexing of interpolated solutions
function augment(A::DiffEqArray, sol::AbstractODESolution)
syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing
function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B}
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
if has_sys(sol.prob.f)
DiffEqArray{T, N, Q, B, typeof(sol.prob.f.sys), typeof(observed), typeof(p)}(A.u,
A.t,
sol.prob.f.sys,
observed,
p)
else
syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
end

# Symbol Handling
Expand Down Expand Up @@ -61,14 +69,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,
end

Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym)
paramsyms = getparamsyms(A)
if issymbollike(sym)
if sym isa AbstractArray
return A[collect(sym)]
end
i = sym_to_index(sym, A)
elseif all(issymbollike, sym)
if all(in(paramsyms), Symbol.(sym))
if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) ||
!has_sys(A.prob.f) && has_paramsyms(A.prob.f) &&
all(in(getparamsyms(A)), Symbol.(sym))
return getindex.((A,), sym)
else
return [getindex.((A,), sym, i) for i in eachindex(A)]
Expand All @@ -77,13 +86,18 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
i = sym
end

indepsym = getindepsym(A)
paramsyms = getparamsyms(A)
if i === nothing
if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym
A.t
elseif issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms
A.prob.p[findfirst(x -> isequal(x, Symbol(sym)), paramsyms)]
if issymbollike(sym)
if has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.prob.f) && is_param_sym(A.prob.f.sys, sym)
return A.prob.p[param_sym_to_index(A.prob.f.sys, sym)]
elseif has_paramsyms(A.prob.f) && Symbol(sym) in getparamsyms(A)
return A.prob.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
else
return observed(A, sym, :)
end
else
observed(A, sym, :)
end
Expand All @@ -106,9 +120,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s
i = sym
end

indepsym = getindepsym(A)
if i === nothing
if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym
if issymbollike(sym) && has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
A.t[args...]
else
observed(A, sym, args...)
Expand Down Expand Up @@ -411,7 +425,13 @@ function cleansym(sym::Symbol)
return str
end

sym_to_index(sym, sol::AbstractSciMLSolution) = sym_to_index(sym, getsyms(sol))
function sym_to_index(sym, sol::AbstractSciMLSolution)
if has_sys(sol.prob.f) && is_state_sym(sol.prob.f.sys, sym)
return state_sym_to_index(sol.prob.f.sys, sym)
else
return sym_to_index(sym, getsyms(sol))
end
end
sym_to_index(sym, syms) = findfirst(isequal(Symbol(sym)), syms)
const issymbollike = RecursiveArrayTools.issymbollike

Expand Down
2 changes: 1 addition & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BoundaryValueDiffEq = "2.10"
ModelingToolkit = "8.35"
ModelingToolkit = "8.37"
Optimization = "3"
OptimizationOptimJL = "0.1"
OrdinaryDiffEq = "6.33"
Expand Down