Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ function build_torn_function(sys;
out,
rhss)

states = fullvars[states_idxs]
states = Any[fullvars[i] for i in states_idxs]
@set! sys.solver_states = states
syms = map(Symbol, states)

pre = get_postprocess_fbody(sys)
Expand Down
15 changes: 12 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ for prop in [:eqs
:tearing_state
:substitutions
:metadata
:discrete_subsystems]
:discrete_subsystems
:solver_states]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down Expand Up @@ -466,7 +467,7 @@ function namespace_expr(O, sys, n = nameof(sys))
end
end

function SymbolicIndexingInterface.states(sys::AbstractSystem)
function states(sys::AbstractSystem)
sts = get_states(sys)
systems = get_systems(sys)
unique(isempty(systems) ?
Expand Down Expand Up @@ -580,8 +581,16 @@ end

SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))

function solver_states(sys::AbstractSystem)
sts = states(sys)
if has_solver_states(sys)
sts = something(get_solver_states(sys), sts)
end
return sts
end

function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), SymbolicIndexingInterface.states(sys))
findfirst(isequal(sym), solver_states(sys))
end
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))
Expand Down
11 changes: 8 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,21 @@ struct ODESystem <: AbstractODESystem
"""
complete::Bool
"""
discrete_subsystems: a list of discrete subsystems
discrete_subsystems: a list of discrete subsystems.
"""
discrete_subsystems::Any
"""
solver_states: a list of actual solver states. Only used for ODAEProblem.
"""
solver_states::Union{Nothing, Vector{Any}}

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, tearing_state = nothing,
substitutions = nothing, complete = false,
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
discrete_subsystems = nothing, solver_states = nothing;
checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -149,7 +154,7 @@ struct ODESystem <: AbstractODESystem
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, tearing_state,
substitutions, complete, discrete_subsystems)
substitutions, complete, discrete_subsystems, solver_states)
end
end

Expand Down