Zygote AD & logpdf for transformed multivariate #217

@tpgillam

I've found that Zygote fails to compute gradients (it returns `nothing`) when using the method of `logpdf` defined here.

Here's an MWE:

using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2)
flow = transformed(d, b)

x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)

println()

gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];

With output:

loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711

gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]

Tested on Bijectors v0.10.0.

I'm not sure, but maybe the optimised dispatch for `logpdf` (or some of the methods called within it) needs additional ChainRules support?
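
For context on why a `nothing` gradient is unexpected: assuming `transformed(d, b)` denotes the distribution of y = b(z) with z drawn from `d`, its log-density follows the standard change-of-variables formula, and both terms depend on the parameters of `b`. The symbols p_d, p_flow and J below are notation introduced for this sketch, not identifiers from Bijectors.

```latex
\log p_{\mathrm{flow}}(y)
  = \log p_{d}\!\left(b^{-1}(y)\right)
  + \log\left|\det J_{b^{-1}}(y)\right|
```

Both loss functions in the MWE evaluate this same quantity (they agree at 3.089176357252711), so their gradients with respect to the parameters of `b` should agree as well.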
