diff --git a/Project.toml b/Project.toml index 281d13aeb7..76d0c8c7f5 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] @@ -33,10 +34,11 @@ FunctionWrappersWrappers = "0.1" IteratorInterfaceExtensions = "^0.1, ^1" Preferences = "1.3" RecipesBase = "0.7.0, 0.8, 1.0" -RecursiveArrayTools = "2.14" +RecursiveArrayTools = "2.33" RuntimeGeneratedFunctions = "0.5" StaticArrays = "1" StaticArraysCore = "1" +SymbolicIndexingInterface = "0.2" Tables = "1" julia = "1.6" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 5708d96b71..d318d478e6 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -2,6 +2,7 @@ module SciMLBase using ConstructionBase using RecipesBase, RecursiveArrayTools, Tables +using SymbolicIndexingInterface using DocStringExtensions using LinearAlgebra using Statistics diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index c538f23381..46493ebf31 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -411,7 +411,13 @@ function getobserved(integrator::DEIntegrator) end end -sym_to_index(sym, integrator::DEIntegrator) = sym_to_index(sym, getsyms(integrator)) +function sym_to_index(sym, integrator::DEIntegrator) + if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym) + return state_sym_to_index(integrator.f.sys, sym) + else + return sym_to_index(sym, getsyms(integrator)) + end +end Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, I::Union{Int, AbstractArray{Int}, @@ -432,13 +438,18 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym) i = sym end - indepsym = getindepsym(A) - paramsyms = getparamsyms(A) if i === nothing - if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym - A.t - elseif issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms - A.p[findfirst(isequal(Symbol(sym)), paramsyms)] + if issymbollike(sym) + if has_sys(A.f) && is_indep_sym(A.f.sys, sym) || + Symbol(sym) == getindepsym(A) + return A.t + elseif has_sys(A.f) && is_param_sym(A.f.sys, sym) + return A.p[param_sym_to_index(A.f.sys, sym)] + elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A) + return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))] + else + return observed(A, sym) + end else observed(A, sym) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index e60bdf9e52..551d734671 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -3654,6 +3654,7 @@ has_tgrad(f::AbstractSciMLFunction) = __has_tgrad(f) && f.tgrad !== nothing has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing +has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing has_syms(f::AbstractSciMLFunction) = __has_syms(f) && f.syms !== nothing has_indepsym(f::AbstractSciMLFunction) = __has_indepsym(f) && f.indepsym !== nothing has_paramsyms(f::AbstractSciMLFunction) = __has_paramsyms(f) && f.paramsyms !== nothing diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 99cabcff92..0b899a5256 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -82,21 +82,31 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Integer, continuity) where {deriv} A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) - syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? - [sol.prob.f.syms[idxs]] : nothing observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + if has_sys(sol.prob.f) + DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed), + typeof(p)}(A.u, A.t, sol.prob.f.sys, observed, p) + else + syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? + [sol.prob.f.syms[idxs]] : nothing + DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + end end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) - syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? - sol.prob.f.syms[idxs] : nothing observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + if has_sys(sol.prob.f) + DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed), + typeof(p)}(A.u, A.t, sol.prob.f.sys, observed, p) + else + syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? + sol.prob.f.syms[idxs] : nothing + DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + end end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, @@ -118,7 +128,12 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p) + if has_sys(sol.prob.f) + return DiffEqArray(interp_sol[idxs], t, [idxs], + independent_variables(sol.prob.f.sys), observed, p) + else + return DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p) + end end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, @@ -127,8 +142,15 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, idxs, - getindepsym(sol), observed, p) + if has_sys(sol.prob.f) + return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, + idxs, + independent_variables(sol.prob.f.sys), observed, p) + else + return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, + idxs, + getindepsym(sol), observed, p) + end end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 5180114245..525696136c 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -18,11 +18,19 @@ function Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution) end # For augmenting system information to enable symbol based indexing of interpolated solutions -function augment(A::DiffEqArray, sol::AbstractODESolution) - syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing +function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B} observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + if has_sys(sol.prob.f) + DiffEqArray{T, N, Q, B, typeof(sol.prob.f.sys), typeof(observed), typeof(p)}(A.u, + A.t, + sol.prob.f.sys, + observed, + p) + else + syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing + DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) + end end # Symbol Handling @@ -61,14 +69,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, end Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym) - paramsyms = getparamsyms(A) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym)] end i = sym_to_index(sym, A) elseif all(issymbollike, sym) - if all(in(paramsyms), Symbol.(sym)) + if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || + !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && + all(in(getparamsyms(A)), Symbol.(sym)) return getindex.((A,), sym) else return [getindex.((A,), sym, i) for i in eachindex(A)] @@ -77,13 +86,18 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s i = sym end - indepsym = getindepsym(A) - paramsyms = getparamsyms(A) if i === nothing - if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym - A.t - elseif issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms - A.prob.p[findfirst(x -> isequal(x, Symbol(sym)), paramsyms)] + if issymbollike(sym) + if has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || + Symbol(sym) == getindepsym(A) + return A.t + elseif has_sys(A.prob.f) && is_param_sym(A.prob.f.sys, sym) + return A.prob.p[param_sym_to_index(A.prob.f.sys, sym)] + elseif has_paramsyms(A.prob.f) && Symbol(sym) in getparamsyms(A) + return A.prob.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))] + else + return observed(A, sym, :) + end else observed(A, sym, :) end @@ -106,9 +120,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s i = sym end - indepsym = getindepsym(A) if i === nothing - if issymbollike(sym) && indepsym !== nothing && Symbol(sym) == indepsym + if issymbollike(sym) && has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || + Symbol(sym) == getindepsym(A) A.t[args...] else observed(A, sym, args...) @@ -411,7 +425,13 @@ function cleansym(sym::Symbol) return str end -sym_to_index(sym, sol::AbstractSciMLSolution) = sym_to_index(sym, getsyms(sol)) +function sym_to_index(sym, sol::AbstractSciMLSolution) + if has_sys(sol.prob.f) && is_state_sym(sol.prob.f.sys, sym) + return state_sym_to_index(sol.prob.f.sys, sym) + else + return sym_to_index(sym, getsyms(sol)) + end +end sym_to_index(sym, syms) = findfirst(isequal(Symbol(sym)), syms) const issymbollike = RecursiveArrayTools.issymbollike diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index fa4335497b..fc40eb4077 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -11,7 +11,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BoundaryValueDiffEq = "2.10" -ModelingToolkit = "8.35" +ModelingToolkit = "8.37" Optimization = "3" OptimizationOptimJL = "0.1" OrdinaryDiffEq = "6.33"