Skip to content

Commit 548b9d5

Browse files
test: extensively test getu/setu and their type-stability
1 parent 78aed2f commit 548b9d5

File tree

1 file changed

+106
-11
lines changed

1 file changed

+106
-11
lines changed

test/downstream/symbol_indexing.jl

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,11 @@ l2x_idx = variable_index(sol, lorenz2.x)
188188
l1y_idx = variable_index(sol, lorenz1.y)
189189
l2y_idx = variable_index(sol, lorenz2.y)
190190

191-
@test getx(sol) == sol[:, l1x_idx]
192-
@test get_arr(sol) == sol[:, [l1x_idx, l2x_idx]]
193-
@test get_tuple(sol) == tuple.(sol[:, l1x_idx], sol[:, l2x_idx])
194-
@test get_obs(sol) == sol[:, l1x_idx] + sol[:, l2x_idx]
195-
@test get_obs_arr(sol) == vcat.(sol[:, l1x_idx] + sol[:, l2x_idx], sol[:, l1y_idx] + sol[:, l2y_idx])
191+
@test getx(sol) == sol[l1x_idx, :]
192+
@test get_arr(sol) == vcat.(sol[l1x_idx, :], sol[l2x_idx, :])
193+
@test get_tuple(sol) == tuple.(sol[l1x_idx, :], sol[l2x_idx, :])
194+
@test get_obs(sol) == sol[l1x_idx, :] + sol[l2x_idx, :]
195+
@test get_obs_arr(sol) == vcat.(sol[l1x_idx, :] + sol[l2x_idx, :], sol[l1y_idx, :] + sol[l2y_idx, :])
196196

197197
#=
198198
using Plots
@@ -217,14 +217,109 @@ sol = solve(prob, Tsit5())
217217
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
218218
@test sol.ps[p] == [1, 2, 3]
219219

220-
getx = getu(sys, x)
221-
get_mix_arr = getu(sys, [x, y])
222-
get_mix_tuple = getu(sys, (x, y))
223220
x_idx = variable_index.((sys,), [x[1], x[2], x[3]])
224221
y_idx = variable_index(sys, y)
225-
@test getx(sol) == sol[:, x_idx]
226-
@test get_mix_arr(sol) == vcat.(sol[:, x_idx], sol[:, y_idx])
227-
@test get_mix_tuple(sol) == tuple.(sol[:, x_idx], sol[:, y_idx])
222+
x_val = vcat.(getindex.((sol,), x_idx, :)...)
223+
y_val = sol[y_idx, :]
224+
obs_val = sol[x[1] + y]
225+
226+
# checking inference for mixed-type arrays will always fail
227+
for (sym, val, check_inference) in [
228+
(x, x_val, true),
229+
(y, y_val, true),
230+
(y_idx, y_val, true),
231+
(x_idx, x_val, true),
232+
(x[1] + y, obs_val, true),
233+
([x[1], x[2]], sol[[x[1], x[2]]], true),
234+
([x[1], x_idx[2]], sol[[x[1], x[2]]], true),
235+
([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false),
236+
([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false),
237+
([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false),
238+
([x, x], [[i, i] for i in x_val], true),
239+
([x, x_idx], [[i, i] for i in x_val], false),
240+
((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true),
241+
((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true),
242+
((x, x), [(i, i) for i in x_val], true),
243+
((x, x_idx), [(i, i) for i in x_val], true),
244+
((x, x[1]+y), [(i, j) for (i, j) in zip(x_val, obs_val)], true),
245+
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true),
246+
([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
247+
((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], missing),
248+
([x, [x[1] + y, y], (x[1] + y, y_idx)], [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
249+
]
250+
if check_inference === missing
251+
@test_broken @inferred getu(prob, sym)(sol)
252+
elseif check_inference
253+
@inferred getu(prob, sym)(sol)
254+
end
255+
@test getu(prob, sym)(sol) == val
256+
end
257+
258+
x_newval = [3.0, 6.0, 9.0]
259+
y_newval = 4.0
260+
x_probval = prob[x]
261+
y_probval = prob[y]
262+
263+
for (sym, oldval, newval, check_inference) in [
264+
(x, x_probval, x_newval, true),
265+
(y, y_probval, y_newval, true),
266+
(x_idx, x_probval, x_newval, true),
267+
(y_idx, y_probval, y_newval, true),
268+
((x, y), (x_probval, y_probval), (x_newval, y_newval), true),
269+
([x, y], [x_probval, y_probval], [x_newval, y_newval], false),
270+
((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true),
271+
([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false),
272+
((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true),
273+
([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false),
274+
([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], [x_newval[1:2], [y_newval, x_newval[3]]], true),
275+
([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], [x_newval[1:2], (y_newval, x_newval[3])], false),
276+
((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), (x_newval[1:2], [y_newval, x_newval[3]]), true),
277+
((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), (x_newval[1:2], (y_newval, x_newval[3])), true),
278+
]
279+
getter = getu(prob, sym)
280+
setter! = setu(prob, sym)
281+
if check_inference
282+
@inferred getter(prob)
283+
end
284+
@test getter(prob) == oldval
285+
if check_inference
286+
@inferred setter!(prob, newval)
287+
else
288+
setter!(prob, newval)
289+
end
290+
@test getter(prob) == newval
291+
setter!(prob, oldval)
292+
@test getter(prob) == oldval
293+
end
294+
295+
pval = [1.0, 2.0, 3.0]
296+
pval_new = [4.0, 5.0, 6.0]
297+
298+
for (sym, oldval, newval, check_inference) in [
299+
(p[1], pval[1], pval_new[1], true),
300+
(p, pval, pval_new, true),
301+
((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true),
302+
([p[1], p[2]], pval[1:2], pval_new[1:2], true),
303+
((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true),
304+
([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false),
305+
((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), (pval_new[1], (pval_new[2],), [pval_new[3]]), true),
306+
([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false),
307+
]
308+
getter = getp(prob, sym)
309+
setter! = setp(prob, sym)
310+
if check_inference
311+
@inferred getter(prob)
312+
end
313+
@test getter(prob) == oldval
314+
if check_inference
315+
@inferred setter!(prob, newval)
316+
else
317+
setter!(prob, newval)
318+
end
319+
@test getter(prob) == newval
320+
setter!(prob, oldval)
321+
@test getter(prob) == oldval
322+
end
228323

229324
# accessing parameters
230325
@variables t x(t)

0 commit comments

Comments
 (0)