Closed as not planned
Description
I've found that Zygote fails to compute gradients (it returns `nothing`) when using the method of `logpdf` defined here.
Here's an MWE:
using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)

# Each column of `x` is one 2-dimensional observation.
x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)
println()

# Gradient through the batched (matrix) `logpdf` method.
gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

# Gradient through the per-column `logpdf` calls.
gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];

With output:
loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711
gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]
Tested on Bijectors v0.10.0.
I'm not sure, but maybe the optimised `logpdf` dispatch (or some of the methods it calls) needs additional ChainRules support?