Skip to content

Commit af0abd1

Browse files
authored
Merge pull request #346 from SciML/dds_mtk
Updating MTK interface
2 parents d6d9043 + 1115554 commit af0abd1

File tree

5 files changed

+176
-78
lines changed

5 files changed

+176
-78
lines changed

.github/workflows/SpellCheck.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ jobs:
1111
uses: actions/checkout@v4
1212
- name: Check spelling
1313
uses: crate-ci/[email protected]
14+
with:
15+
files: ./src ./docs

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ PrecompileTools = "1.2"
4848
Primes = "0.5"
4949
Random = "1.6, 1.7"
5050
SpecialFunctions = "2"
51-
SymbolicUtils = "2"
52-
Symbolics = "5.30.1"
51+
SymbolicUtils = "2, 3"
52+
Symbolics = "5.30.1, 6"
5353
Test = "1.6, 1.7"
5454
TestSetExtensions = "2"
5555
TimerOutputs = "0.5"

docs/src/tutorials/discrete_time.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ eqs = [
8989
R(k) ~ R(k - 1) + α * I(k - 1),
9090
]
9191
92-
@mtkbuild sys = DiscreteSystem(eqs, t)
92+
@named sys = DiscreteSystem(eqs, t)
9393
9494
assess_local_identifiability(sys, measured_quantities = [I])
9595
```

ext/ModelingToolkitSIExt.jl

Lines changed: 47 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic,
4646
return args[1]^args[2]
4747
end
4848
return 1 // args[1]^(-args[2])
49+
# dirty way, assumes that all shifts should be just removed
50+
elseif startswith(String(Symbol(Symbolics.operation(e))), "Shift")
51+
return args[1]
4952
end
5053
throw(Base.ArgumentError("Function $(Symbolics.operation(e)) is not supported"))
5154
elseif e isa Symbolics.Symbolic
@@ -71,12 +74,11 @@ function StructuralIdentifiability.eval_at_nemo(
7174
end
7275

7376
function get_measured_quantities(ode::ModelingToolkit.ODESystem)
74-
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
75-
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
76-
return filter(
77-
eq -> (ModelingToolkit.isoutput(eq.lhs)),
78-
ModelingToolkit.equations(ode),
79-
)
77+
outputs = filter(eq -> ModelingToolkit.isoutput(eq.lhs), ModelingToolkit.equations(ode))
78+
if !isempty(outputs)
79+
return outputs
80+
elseif !isempty(ModelingToolkit.observed(ode))
81+
return ModelingToolkit.observed(ode)
8082
else
8183
throw(
8284
error(
@@ -103,6 +105,9 @@ function StructuralIdentifiability.mtk_to_si(
103105
de::ModelingToolkit.AbstractTimeDependentSystem,
104106
measured_quantities::Array{ModelingToolkit.Equation},
105107
)
108+
if isempty(measured_quantities)
109+
measured_quantities = get_measured_quantities(de)
110+
end
106111
return __mtk_to_si(
107112
de,
108113
[(replace(string(e.lhs), "(t)" => ""), e.rhs) for e in measured_quantities],
@@ -153,6 +158,20 @@ function preprocess_ode(
153158
return mtk_to_si(de, measured_quantities)
154159
end
155160

161+
#------------------------------------------------------------------------------
162+
function clean_calls(funcs)
163+
res = []
164+
for f in funcs
165+
if length(Symbolics.arguments(f)) == 1 &&
166+
!Symbolics.iscall(first(Symbolics.arguments(f)))
167+
push!(res, f)
168+
else
169+
push!(res, first(Symbolics.arguments(f)))
170+
end
171+
end
172+
return res
173+
end
174+
156175
#------------------------------------------------------------------------------
157176
"""
158177
function __mtk_to_si(de::ModelingToolkit.AbstractTimeDependentSystem, measured_quantities::Array{Tuple{String, SymbolicUtils.BasicSymbolic}})
@@ -186,11 +205,10 @@ function __mtk_to_si(
186205
end
187206

188207
y_functions = [each[2] for each in measured_quantities]
189-
inputs = filter(v -> ModelingToolkit.isinput(v), ModelingToolkit.unknowns(de))
190-
state_vars = filter(
191-
s -> !(ModelingToolkit.isinput(s) || ModelingToolkit.isoutput(s)),
192-
ModelingToolkit.unknowns(de),
193-
)
208+
state_vars =
209+
filter(s -> !ModelingToolkit.isoutput(s), clean_calls(map(e -> e.lhs, diff_eqs)))
210+
all_funcs = collect(Set(clean_calls(ModelingToolkit.unknowns(de))))
211+
inputs = filter(s -> !ModelingToolkit.isoutput(s), setdiff(all_funcs, state_vars))
194212
params = ModelingToolkit.parameters(de)
195213
t = ModelingToolkit.arguments(diff_eqs[1].lhs)[1]
196214
params_from_measured_quantities = union(
@@ -240,7 +258,7 @@ function __mtk_to_si(
240258
end
241259
# -----------------------------------------------------------------------------
242260
"""
243-
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
261+
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
244262
245263
Input:
246264
- `ode` - the ODESystem object from ModelingToolkit
@@ -263,7 +281,7 @@ The return value is a tuple consisting of the array of bools and the number of e
263281
"""
264282
function StructuralIdentifiability.assess_local_identifiability(
265283
ode::ModelingToolkit.ODESystem;
266-
measured_quantities = Array{ModelingToolkit.Equation}[],
284+
measured_quantities = ModelingToolkit.Equation[],
267285
funcs_to_check = Array{}[],
268286
prob_threshold::Float64 = 0.99,
269287
type = :SE,
@@ -288,28 +306,13 @@ end
288306
prob_threshold::Float64 = 0.99,
289307
type = :SE,
290308
)
291-
if length(measured_quantities) == 0
292-
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
293-
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
294-
measured_quantities = filter(
295-
eq -> (ModelingToolkit.isoutput(eq.lhs)),
296-
ModelingToolkit.equations(ode),
297-
)
298-
else
299-
throw(
300-
error(
301-
"Measured quantities (output functions) were not provided and no outputs were found.",
302-
),
303-
)
304-
end
305-
end
306-
if length(funcs_to_check) == 0
307-
funcs_to_check = vcat(
308-
[e for e in ModelingToolkit.unknowns(ode) if !ModelingToolkit.isoutput(e)],
309-
ModelingToolkit.parameters(ode),
310-
)
311-
end
312309
ode, conversion = mtk_to_si(ode, measured_quantities)
310+
@info "System parsed into $ode"
311+
conversion_back = Dict(v => k for (k, v) in conversion)
312+
if isempty(funcs_to_check)
313+
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
314+
end
315+
313316
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
314317

315318
if isequal(type, :SE)
@@ -340,7 +343,7 @@ end
340343
# ------------------------------------------------------------------------------
341344

342345
"""
343-
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
346+
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
344347
345348
Input:
346349
- `ode` - the ModelingToolkit.ODESystem object that defines the model
@@ -356,7 +359,7 @@ If known initial conditions are provided, the identifiability results for the st
356359
"""
357360
function StructuralIdentifiability.assess_identifiability(
358361
ode::ModelingToolkit.ODESystem;
359-
measured_quantities = Array{ModelingToolkit.Equation}[],
362+
measured_quantities = ModelingToolkit.Equation[],
360363
funcs_to_check = [],
361364
known_ic = [],
362365
prob_threshold = 0.99,
@@ -376,16 +379,13 @@ end
376379

377380
function _assess_identifiability(
378381
ode::ModelingToolkit.ODESystem;
379-
measured_quantities = Array{ModelingToolkit.Equation}[],
382+
measured_quantities = ModelingToolkit.Equation[],
380383
funcs_to_check = [],
381384
known_ic = [],
382385
prob_threshold = 0.99,
383386
)
384-
if isempty(measured_quantities)
385-
measured_quantities = get_measured_quantities(ode)
386-
end
387-
388387
ode, conversion = mtk_to_si(ode, measured_quantities)
388+
@info "System parsed into $ode"
389389
conversion_back = Dict(v => k for (k, v) in conversion)
390390
if isempty(funcs_to_check)
391391
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
@@ -470,43 +470,29 @@ function _assess_local_identifiability(
470470
known_ic = Array{}[],
471471
prob_threshold::Float64 = 0.99,
472472
)
473-
if length(measured_quantities) == 0
474-
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(dds))
475-
@info "Measured quantities are not provided, trying to find the outputs in input dynamical system."
476-
measured_quantities = filter(
477-
eq -> (ModelingToolkit.isoutput(eq.lhs)),
478-
ModelingToolkit.equations(dds),
479-
)
480-
else
481-
throw(
482-
error(
483-
"Measured quantities (output functions) were not provided and no outputs were found.",
484-
),
485-
)
486-
end
487-
end
488-
489473
# Converting the finite difference operator in the right-hand side to
490474
# the corresponding shift operator
491475
eqs = filter(eq -> !(ModelingToolkit.isoutput(eq.lhs)), ModelingToolkit.equations(dds))
492476

493477
dds_aux_ode, conversion = mtk_to_si(dds, measured_quantities)
494478
dds_aux = StructuralIdentifiability.DDS{QQMPolyRingElem}(dds_aux_ode)
479+
@info "Parsed into the following model: $dds_aux"
495480
if length(funcs_to_check) == 0
496481
params = parameters(dds)
497482
params_from_measured_quantities = union(
498483
[filter(s -> !iscall(s), get_variables(y)) for y in measured_quantities]...,
499484
)
500485
funcs_to_check = vcat(
501486
[
502-
x for x in unknowns(dds) if
487+
x for x in clean_calls(unknowns(dds)) if
503488
conversion[x] in StructuralIdentifiability.x_vars(dds_aux)
504489
],
505490
union(params, params_from_measured_quantities),
506491
)
507492
end
508493
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
509494
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]
495+
@info "Functions to check are $(["$f" for f in funcs_to_check_]) and initial conditions are known for $(["$f" for f in known_ic_])"
510496

511497
result = StructuralIdentifiability._assess_local_identifiability_discrete_aux(
512498
dds_aux,
@@ -568,7 +554,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0])
568554
"""
569555
function StructuralIdentifiability.find_identifiable_functions(
570556
ode::ModelingToolkit.ODESystem;
571-
measured_quantities = Array{ModelingToolkit.Equation}[],
557+
measured_quantities = ModelingToolkit.Equation[],
572558
known_ic = [],
573559
prob_threshold::Float64 = 0.99,
574560
seed = 42,
@@ -595,18 +581,15 @@ end
595581

596582
function _find_identifiable_functions(
597583
ode::ModelingToolkit.ODESystem;
598-
measured_quantities = Array{ModelingToolkit.Equation}[],
599-
known_ic = Array{Symbolics.Num}[],
584+
measured_quantities = ModelingToolkit.Equation[],
585+
known_ic = Symbolics.Num[],
600586
prob_threshold::Float64 = 0.99,
601587
seed = 42,
602588
with_states = false,
603589
simplify = :standard,
604590
rational_interpolator = :VanDerHoevenLecerf,
605591
)
606592
Random.seed!(seed)
607-
if isempty(measured_quantities)
608-
measured_quantities = get_measured_quantities(ode)
609-
end
610593
ode, conversion = mtk_to_si(ode, measured_quantities)
611594
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]
612595
result = nothing

0 commit comments

Comments
 (0)