Skip to content

Commit e1c3a4a

Browse files
committed
feat: update training api to account for thunking
1 parent 63750f1 commit e1c3a4a

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

ext/LuxZygoteExt/training.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
function Lux.Training.compute_gradients_impl(
22
::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F}
3-
(loss, st, stats), back = Zygote.pullback(
4-
objective_function, ts.model, ts.parameters, ts.states, data)
5-
grads = back((one(loss), nothing, nothing))[2]
3+
@static if pkgversion(Zygote) ≥ v"0.7-"
4+
# Zygote 0.7 doesn't aggressively unthunk everything, so it is better to use a
5+
# closure here
6+
(loss, st, stats), back = Zygote.pullback(
7+
ps -> objective_function(ts.model, ps, ts.states, data), ts.parameters)
8+
grads = only(back((one(loss), nothing, nothing)))
9+
else
10+
(loss, st, stats), back = Zygote.pullback(
11+
objective_function, ts.model, ts.parameters, ts.states, data
12+
)
13+
grads = back((one(loss), nothing, nothing))[2]
14+
end
615
@set! ts.states = st
716
return grads, loss, stats, ts
817
end

0 commit comments

Comments
 (0)