|
1 | 1 | function Lux.Training.compute_gradients_impl( |
2 | 2 | ::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F} |
3 | | - (loss, st, stats), back = Zygote.pullback( |
4 | | - objective_function, ts.model, ts.parameters, ts.states, data) |
5 | | - grads = back((one(loss), nothing, nothing))[2] |
| 3 | + @static if pkgversion(Zygote) ≥ v"0.7-" |
| 4 | + # Zygote 0.7 doesn't aggressively unthunk everything, so it is better to use a |
| 5 | + # closure here |
| 6 | + (loss, st, stats), back = Zygote.pullback( |
| 7 | + ps -> objective_function(ts.model, ps, ts.states, data), ts.parameters) |
| 8 | + grads = only(back((one(loss), nothing, nothing))) |
| 9 | + else |
| 10 | + (loss, st, stats), back = Zygote.pullback( |
| 11 | + objective_function, ts.model, ts.parameters, ts.states, data |
| 12 | + ) |
| 13 | + grads = back((one(loss), nothing, nothing))[2] |
| 14 | + end |
6 | 15 | @set! ts.states = st |
7 | 16 | return grads, loss, stats, ts |
8 | 17 | end |
0 commit comments