Skip to content

Commit 071ad2f

Browse files
chore: format
1 parent 453a933 commit 071ad2f

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ end
120120
VA[sym], ODESolution_getindex_pullback
121121
end
122122

123-
@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T
123+
@adjoint function Base.getindex(
124+
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
124125
function ODESolution_getindex_pullback(Δ)
125126
sym = sym isa Tuple ? collect(sym) : sym
126127
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)

test/downstream/symbol_indexing.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Zygote, Test
1+
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface,
2+
Zygote, Test
23
using Optimization, OptimizationOptimJL
34
using ModelingToolkit: t_nounits as t, D_nounits as D
45

@@ -98,29 +99,29 @@ end
9899
@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :])
99100

100101
gs_sym, = Zygote.gradient(sol) do sol
101-
sum(sol[lorenz1.x])
102+
sum(sol[lorenz1.x])
102103
end
103104
idx_sym = SymbolicIndexingInterface.variable_index(sys, lorenz1.x)
104105
true_grad_sym = zeros(length(ModelingToolkit.unknowns(sys)))
105-
true_grad_sym[idx_sym] = 1.
106+
true_grad_sym[idx_sym] = 1.0
106107

107108
@test all(map(x -> x == true_grad_sym, gs_sym))
108109

109110
gs_vec, = Zygote.gradient(sol) do sol
110-
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
111+
sum(sum.(sol[[lorenz1.x, lorenz2.x]]))
111112
end
112113
idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
113114
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
114-
true_grad_vecsym[idx_vecsym] .= 1.
115+
true_grad_vecsym[idx_vecsym] .= 1.0
115116

116117
@test all(map(x -> x == true_grad_vecsym, gs_vec))
117118

118119
gs_tup, = Zygote.gradient(sol) do sol
119-
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
120+
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
120121
end
121122
idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lorenz2.x])
122123
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
123-
true_grad_tupsym[idx_tupsym] .= 1.
124+
true_grad_tupsym[idx_tupsym] .= 1.0
124125

125126
@test all(map(x -> x == true_grad_tupsym, gs_tup))
126127

0 commit comments

Comments
 (0)