@@ -188,11 +188,11 @@ l2x_idx = variable_index(sol, lorenz2.x)
188188l1y_idx = variable_index (sol, lorenz1. y)
189189l2y_idx = variable_index (sol, lorenz2. y)
190190
191- @test getx (sol) == sol[:, l1x_idx ]
192- @test get_arr (sol) == sol[:, [l1x_idx, l2x_idx]]
193- @test get_tuple (sol) == tuple .(sol[:, l1x_idx ], sol[:, l2x_idx ])
194- @test get_obs (sol) == sol[:, l1x_idx ] + sol[:, l2x_idx ]
195- @test get_obs_arr (sol) == vcat .(sol[:, l1x_idx ] + sol[:, l2x_idx ], sol[:, l1y_idx ] + sol[:, l2y_idx ])
191+ @test getx (sol) == sol[l1x_idx, : ]
192+ @test get_arr (sol) == vcat .( sol[l1x_idx, :], sol[l2x_idx, :])
193+ @test get_tuple (sol) == tuple .(sol[l1x_idx, : ], sol[l2x_idx, : ])
194+ @test get_obs (sol) == sol[l1x_idx, : ] + sol[l2x_idx, : ]
195+ @test get_obs_arr (sol) == vcat .(sol[l1x_idx, : ] + sol[l2x_idx, : ], sol[l1y_idx, : ] + sol[l2y_idx, : ])
196196
197197#=
198198using Plots
@@ -217,14 +217,109 @@ sol = solve(prob, Tsit5())
217217@test sol[@nonamespace sys. x] isa Vector{<: Vector }
218218@test sol. ps[p] == [1 , 2 , 3 ]
219219
220- getx = getu (sys, x)
221- get_mix_arr = getu (sys, [x, y])
222- get_mix_tuple = getu (sys, (x, y))
223220x_idx = variable_index .((sys,), [x[1 ], x[2 ], x[3 ]])
224221y_idx = variable_index (sys, y)
225- @test getx (sol) == sol[:, x_idx]
226- @test get_mix_arr (sol) == vcat .(sol[:, x_idx], sol[:, y_idx])
227- @test get_mix_tuple (sol) == tuple .(sol[:, x_idx], sol[:, y_idx])
222+ x_val = vcat .(getindex .((sol,), x_idx, :)... )
223+ y_val = sol[y_idx, :]
224+ obs_val = sol[x[1 ] + y]
225+
226+ # checking inference for mixed-type arrays will always fail
227+ for (sym, val, check_inference) in [
228+ (x, x_val, true ),
229+ (y, y_val, true ),
230+ (y_idx, y_val, true ),
231+ (x_idx, x_val, true ),
232+ (x[1 ] + y, obs_val, true ),
233+ ([x[1 ], x[2 ]], sol[[x[1 ], x[2 ]]], true ),
234+ ([x[1 ], x_idx[2 ]], sol[[x[1 ], x[2 ]]], true ),
235+ ([x, x[1 ] + y], [[i, j] for (i, j) in zip (x_val, obs_val)], false ),
236+ ([x, y], [[i, j] for (i, j) in zip (x_val, y_val)], false ),
237+ ([x, y_idx], [[i, j] for (i, j) in zip (x_val, y_val)], false ),
238+ ([x, x], [[i, i] for i in x_val], true ),
239+ ([x, x_idx], [[i, i] for i in x_val], false ),
240+ ((x, y), [(i, j) for (i, j) in zip (x_val, y_val)], true ),
241+ ((x, y_idx), [(i, j) for (i, j) in zip (x_val, y_val)], true ),
242+ ((x, x), [(i, i) for i in x_val], true ),
243+ ((x, x_idx), [(i, i) for i in x_val], true ),
244+ ((x, x[1 ]+ y), [(i, j) for (i, j) in zip (x_val, obs_val)], true ),
245+ ((x, (x[1 ] + y, y)), [(i, (k, j)) for (i, j, k) in zip (x_val, y_val, obs_val)], true ),
246+ ([x, [x[1 ] + y, y]], [[i, [k, j]] for (i, j, k) in zip (x_val, y_val, obs_val)], false ),
247+ ((x, [x[1 ] + y, y], (x[1 ] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip (x_val, y_val, obs_val)], missing ),
248+ ([x, [x[1 ] + y, y], (x[1 ] + y, y_idx)], [[i, [k, j], (k, j)] for (i, j, k) in zip (x_val, y_val, obs_val)], false ),
249+ ]
250+ if check_inference === missing
251+ @test_broken @inferred getu (prob, sym)(sol)
252+ elseif check_inference
253+ @inferred getu (prob, sym)(sol)
254+ end
255+ @test getu (prob, sym)(sol) == val
256+ end
257+
258+ x_newval = [3.0 , 6.0 , 9.0 ]
259+ y_newval = 4.0
260+ x_probval = prob[x]
261+ y_probval = prob[y]
262+
263+ for (sym, oldval, newval, check_inference) in [
264+ (x, x_probval, x_newval, true ),
265+ (y, y_probval, y_newval, true ),
266+ (x_idx, x_probval, x_newval, true ),
267+ (y_idx, y_probval, y_newval, true ),
268+ ((x, y), (x_probval, y_probval), (x_newval, y_newval), true ),
269+ ([x, y], [x_probval, y_probval], [x_newval, y_newval], false ),
270+ ((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true ),
271+ ([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false ),
272+ ((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true ),
273+ ([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false ),
274+ ([x[1 : 2 ], [y_idx, x[3 ]]], [x_probval[1 : 2 ], [y_probval, x_probval[3 ]]], [x_newval[1 : 2 ], [y_newval, x_newval[3 ]]], true ),
275+ ([x[1 : 2 ], (y_idx, x[3 ])], [x_probval[1 : 2 ], (y_probval, x_probval[3 ])], [x_newval[1 : 2 ], (y_newval, x_newval[3 ])], false ),
276+ ((x[1 : 2 ], [y_idx, x[3 ]]), (x_probval[1 : 2 ], [y_probval, x_probval[3 ]]), (x_newval[1 : 2 ], [y_newval, x_newval[3 ]]), true ),
277+ ((x[1 : 2 ], (y_idx, x[3 ])), (x_probval[1 : 2 ], (y_probval, x_probval[3 ])), (x_newval[1 : 2 ], (y_newval, x_newval[3 ])), true ),
278+ ]
279+ getter = getu (prob, sym)
280+ setter! = setu (prob, sym)
281+ if check_inference
282+ @inferred getter (prob)
283+ end
284+ @test getter (prob) == oldval
285+ if check_inference
286+ @inferred setter! (prob, newval)
287+ else
288+ setter! (prob, newval)
289+ end
290+ @test getter (prob) == newval
291+ setter! (prob, oldval)
292+ @test getter (prob) == oldval
293+ end
294+
295+ pval = [1.0 , 2.0 , 3.0 ]
296+ pval_new = [4.0 , 5.0 , 6.0 ]
297+
298+ for (sym, oldval, newval, check_inference) in [
299+ (p[1 ], pval[1 ], pval_new[1 ], true ),
300+ (p, pval, pval_new, true ),
301+ ((p[1 ], p[2 ]), Tuple (pval[1 : 2 ]), Tuple (pval_new[1 : 2 ]), true ),
302+ ([p[1 ], p[2 ]], pval[1 : 2 ], pval_new[1 : 2 ], true ),
303+ ((p[1 ], p[2 : 3 ]), (pval[1 ], pval[2 : 3 ]), (pval_new[1 ], pval_new[2 : 3 ]), true ),
304+ ([p[1 ], p[2 : 3 ]], [pval[1 ], pval[2 : 3 ]], [pval_new[1 ], pval_new[2 : 3 ]], false ),
305+ ((p[1 ], (p[2 ],), [p[3 ]]), (pval[1 ], (pval[2 ],), [pval[3 ]]), (pval_new[1 ], (pval_new[2 ],), [pval_new[3 ]]), true ),
306+ ([p[1 ], (p[2 ],), [p[3 ]]], [pval[1 ], (pval[2 ],), [pval[3 ]]], [pval_new[1 ], (pval_new[2 ],), [pval_new[3 ]]], false ),
307+ ]
308+ getter = getp (prob, sym)
309+ setter! = setp (prob, sym)
310+ if check_inference
311+ @inferred getter (prob)
312+ end
313+ @test getter (prob) == oldval
314+ if check_inference
315+ @inferred setter! (prob, newval)
316+ else
317+ setter! (prob, newval)
318+ end
319+ @test getter (prob) == newval
320+ setter! (prob, oldval)
321+ @test getter (prob) == oldval
322+ end
228323
229324# accessing parameters
230325@variables t x (t)
0 commit comments