@@ -46,6 +46,9 @@ function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic,
4646 return args[1 ]^ args[2 ]
4747 end
4848 return 1 // args[1 ]^ (- args[2 ])
49+ # dirty way, assumes that all shifts should be just removed
50+ elseif startswith (String (Symbol (Symbolics. operation (e))), " Shift" )
51+ return args[1 ]
4952 end
5053 throw (Base. ArgumentError (" Function $(Symbolics. operation (e)) is not supported" ))
5154 elseif e isa Symbolics. Symbolic
@@ -71,12 +74,11 @@ function StructuralIdentifiability.eval_at_nemo(
7174end
7275
7376function get_measured_quantities (ode:: ModelingToolkit.ODESystem )
74- if any (ModelingToolkit. isoutput (eq. lhs) for eq in ModelingToolkit. equations (ode))
75- @info " Measured quantities are not provided, trying to find the outputs in input ODE."
76- return filter (
77- eq -> (ModelingToolkit. isoutput (eq. lhs)),
78- ModelingToolkit. equations (ode),
79- )
77+ outputs = filter (eq -> ModelingToolkit. isoutput (eq. lhs), ModelingToolkit. equations (ode))
78+ if ! isempty (outputs)
79+ return outputs
80+ elseif ! isempty (ModelingToolkit. observed (ode))
81+ return ModelingToolkit. observed (ode)
8082 else
8183 throw (
8284 error (
@@ -103,6 +105,9 @@ function StructuralIdentifiability.mtk_to_si(
103105 de:: ModelingToolkit.AbstractTimeDependentSystem ,
104106 measured_quantities:: Array{ModelingToolkit.Equation} ,
105107)
108+ if isempty (measured_quantities)
109+ measured_quantities = get_measured_quantities (de)
110+ end
106111 return __mtk_to_si (
107112 de,
108113 [(replace (string (e. lhs), " (t)" => " " ), e. rhs) for e in measured_quantities],
@@ -153,6 +158,20 @@ function preprocess_ode(
153158 return mtk_to_si (de, measured_quantities)
154159end
155160
161+ # ------------------------------------------------------------------------------
162+ function clean_calls (funcs)
163+ res = []
164+ for f in funcs
165+ if length (Symbolics. arguments (f)) == 1 &&
166+ ! Symbolics. iscall (first (Symbolics. arguments (f)))
167+ push! (res, f)
168+ else
169+ push! (res, first (Symbolics. arguments (f)))
170+ end
171+ end
172+ return res
173+ end
174+
156175# ------------------------------------------------------------------------------
157176"""
158177 function __mtk_to_si(de::ModelingToolkit.AbstractTimeDependentSystem, measured_quantities::Array{Tuple{String, SymbolicUtils.BasicSymbolic}})
@@ -186,11 +205,10 @@ function __mtk_to_si(
186205 end
187206
188207 y_functions = [each[2 ] for each in measured_quantities]
189- inputs = filter (v -> ModelingToolkit. isinput (v), ModelingToolkit. unknowns (de))
190- state_vars = filter (
191- s -> ! (ModelingToolkit. isinput (s) || ModelingToolkit. isoutput (s)),
192- ModelingToolkit. unknowns (de),
193- )
208+ state_vars =
209+ filter (s -> ! ModelingToolkit. isoutput (s), clean_calls (map (e -> e. lhs, diff_eqs)))
210+ all_funcs = collect (Set (clean_calls (ModelingToolkit. unknowns (de))))
211+ inputs = filter (s -> ! ModelingToolkit. isoutput (s), setdiff (all_funcs, state_vars))
194212 params = ModelingToolkit. parameters (de)
195213 t = ModelingToolkit. arguments (diff_eqs[1 ]. lhs)[1 ]
196214 params_from_measured_quantities = union (
@@ -240,7 +258,7 @@ function __mtk_to_si(
240258end
241259# -----------------------------------------------------------------------------
242260"""
243- function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ ModelingToolkit.Equation} [], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
261+ function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
244262
245263Input:
246264- `ode` - the ODESystem object from ModelingToolkit
@@ -263,7 +281,7 @@ The return value is a tuple consisting of the array of bools and the number of e
263281"""
264282function StructuralIdentifiability. assess_local_identifiability (
265283 ode:: ModelingToolkit.ODESystem ;
266- measured_quantities = Array{ ModelingToolkit. Equation} [],
284+ measured_quantities = ModelingToolkit. Equation[],
267285 funcs_to_check = Array{}[],
268286 prob_threshold:: Float64 = 0.99 ,
269287 type = :SE ,
@@ -288,28 +306,13 @@ end
288306 prob_threshold:: Float64 = 0.99 ,
289307 type = :SE ,
290308)
291- if length (measured_quantities) == 0
292- if any (ModelingToolkit. isoutput (eq. lhs) for eq in ModelingToolkit. equations (ode))
293- @info " Measured quantities are not provided, trying to find the outputs in input ODE."
294- measured_quantities = filter (
295- eq -> (ModelingToolkit. isoutput (eq. lhs)),
296- ModelingToolkit. equations (ode),
297- )
298- else
299- throw (
300- error (
301- " Measured quantities (output functions) were not provided and no outputs were found." ,
302- ),
303- )
304- end
305- end
306- if length (funcs_to_check) == 0
307- funcs_to_check = vcat (
308- [e for e in ModelingToolkit. unknowns (ode) if ! ModelingToolkit. isoutput (e)],
309- ModelingToolkit. parameters (ode),
310- )
311- end
312309 ode, conversion = mtk_to_si (ode, measured_quantities)
310+ @info " System parsed into $ode "
311+ conversion_back = Dict (v => k for (k, v) in conversion)
312+ if isempty (funcs_to_check)
313+ funcs_to_check = [conversion_back[x] for x in [ode. x_vars... , ode. parameters... ]]
314+ end
315+
313316 funcs_to_check_ = [eval_at_nemo (x, conversion) for x in funcs_to_check]
314317
315318 if isequal (type, :SE )
340343# ------------------------------------------------------------------------------
341344
342345"""
343- assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ ModelingToolkit.Equation} [], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
346+ assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
344347
345348Input:
346349- `ode` - the ModelingToolkit.ODESystem object that defines the model
@@ -356,7 +359,7 @@ If known initial conditions are provided, the identifiability results for the st
356359"""
357360function StructuralIdentifiability. assess_identifiability (
358361 ode:: ModelingToolkit.ODESystem ;
359- measured_quantities = Array{ ModelingToolkit. Equation} [],
362+ measured_quantities = ModelingToolkit. Equation[],
360363 funcs_to_check = [],
361364 known_ic = [],
362365 prob_threshold = 0.99 ,
@@ -376,16 +379,13 @@ end
376379
377380function _assess_identifiability (
378381 ode:: ModelingToolkit.ODESystem ;
379- measured_quantities = Array{ ModelingToolkit. Equation} [],
382+ measured_quantities = ModelingToolkit. Equation[],
380383 funcs_to_check = [],
381384 known_ic = [],
382385 prob_threshold = 0.99 ,
383386)
384- if isempty (measured_quantities)
385- measured_quantities = get_measured_quantities (ode)
386- end
387-
388387 ode, conversion = mtk_to_si (ode, measured_quantities)
388+ @info " System parsed into $ode "
389389 conversion_back = Dict (v => k for (k, v) in conversion)
390390 if isempty (funcs_to_check)
391391 funcs_to_check = [conversion_back[x] for x in [ode. x_vars... , ode. parameters... ]]
@@ -470,43 +470,29 @@ function _assess_local_identifiability(
470470 known_ic = Array{}[],
471471 prob_threshold:: Float64 = 0.99 ,
472472)
473- if length (measured_quantities) == 0
474- if any (ModelingToolkit. isoutput (eq. lhs) for eq in ModelingToolkit. equations (dds))
475- @info " Measured quantities are not provided, trying to find the outputs in input dynamical system."
476- measured_quantities = filter (
477- eq -> (ModelingToolkit. isoutput (eq. lhs)),
478- ModelingToolkit. equations (dds),
479- )
480- else
481- throw (
482- error (
483- " Measured quantities (output functions) were not provided and no outputs were found." ,
484- ),
485- )
486- end
487- end
488-
489473 # Converting the finite difference operator in the right-hand side to
490474 # the corresponding shift operator
491475 eqs = filter (eq -> ! (ModelingToolkit. isoutput (eq. lhs)), ModelingToolkit. equations (dds))
492476
493477 dds_aux_ode, conversion = mtk_to_si (dds, measured_quantities)
494478 dds_aux = StructuralIdentifiability. DDS {QQMPolyRingElem} (dds_aux_ode)
479+ @info " Parsed into the following model: $dds_aux "
495480 if length (funcs_to_check) == 0
496481 params = parameters (dds)
497482 params_from_measured_quantities = union (
498483 [filter (s -> ! iscall (s), get_variables (y)) for y in measured_quantities]. .. ,
499484 )
500485 funcs_to_check = vcat (
501486 [
502- x for x in unknowns (dds) if
487+ x for x in clean_calls ( unknowns (dds) ) if
503488 conversion[x] in StructuralIdentifiability. x_vars (dds_aux)
504489 ],
505490 union (params, params_from_measured_quantities),
506491 )
507492 end
508493 funcs_to_check_ = [eval_at_nemo (x, conversion) for x in funcs_to_check]
509494 known_ic_ = [eval_at_nemo (x, conversion) for x in known_ic]
495+ @info " Functions to check are $([" $f " for f in funcs_to_check_]) and initial conditions are known for $([" $f " for f in known_ic_]) "
510496
511497 result = StructuralIdentifiability. _assess_local_identifiability_discrete_aux (
512498 dds_aux,
@@ -568,7 +554,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0])
568554"""
569555function StructuralIdentifiability. find_identifiable_functions (
570556 ode:: ModelingToolkit.ODESystem ;
571- measured_quantities = Array{ ModelingToolkit. Equation} [],
557+ measured_quantities = ModelingToolkit. Equation[],
572558 known_ic = [],
573559 prob_threshold:: Float64 = 0.99 ,
574560 seed = 42 ,
@@ -595,18 +581,15 @@ end
595581
596582function _find_identifiable_functions (
597583 ode:: ModelingToolkit.ODESystem ;
598- measured_quantities = Array{ ModelingToolkit. Equation} [],
599- known_ic = Array{ Symbolics. Num} [],
584+ measured_quantities = ModelingToolkit. Equation[],
585+ known_ic = Symbolics. Num[],
600586 prob_threshold:: Float64 = 0.99 ,
601587 seed = 42 ,
602588 with_states = false ,
603589 simplify = :standard ,
604590 rational_interpolator = :VanDerHoevenLecerf ,
605591)
606592 Random. seed! (seed)
607- if isempty (measured_quantities)
608- measured_quantities = get_measured_quantities (ode)
609- end
610593 ode, conversion = mtk_to_si (ode, measured_quantities)
611594 known_ic_ = [eval_at_nemo (each, conversion) for each in known_ic]
612595 result = nothing
0 commit comments