Skip to content

Zygote error differentiating Coupling #203

@Red-Portal

Description

@Red-Portal

Hi, differentiating through `Bijectors.Coupling` with Zygote currently fails with a `MethodError`.
Here is a minimal reproducible example.

using Bijectors
using Flux
using ProgressMeter
using StatsBase
using StatsPlots
using Turing
using Zygote

"""
    main(; n_iter=3000, lr=1e-3, n_samples=10, n_batch=4)

Minimal reproducer for Zygote failing to differentiate through `Bijectors.Coupling`.

The flow is a single `Coupling` layer, whose transform is `Shift(θ) ∘ Scale(θ)`,
applied to a 2-D `MvNormal(zeros(2), ones(2))` base distribution. It is trained
with `ADAM(lr)` for `n_iter` steps. Each step minimises the negative mean
log-density of a mini-batch of `n_batch` columns. The columns are drawn without
replacement from `n_samples` points sampled from `randn`.

Observed with Bijectors v0.9.9 and Zygote v0.6.28: evaluating the pullback
(`back(one(loss))`) throws a `MethodError` (cannot convert `ZeroTangent` to
`NoTangent`).

# Keywords
- `n_iter`: number of optimisation steps.
- `lr`: ADAM learning rate.
- `n_samples`: number of synthetic 2-D data points.
- `n_batch`: mini-batch size; must not exceed `n_samples`.

# Throws
- `ArgumentError` if `n_batch > n_samples`, because batches are sampled without
  replacement.
"""
function main(; n_iter=3000, lr=1e-3, n_samples=10, n_batch=4)
    n_batch <= n_samples || throw(ArgumentError(
        "n_batch ($n_batch) must not exceed n_samples ($n_samples) when sampling without replacement"))

    data = randn(2, n_samples)

    base_dist = MvNormal(zeros(2), ones(2))
    # NOTE(review): assumes `Coupling(f, n)` builds a default partition mask for
    # input dimension `n`. Confirm against the Bijectors docs.
    layers = Bijectors.Coupling(θ -> Bijectors.Shift(θ) ∘ Bijectors.Scale(θ), 2)

    flow = transformed(base_dist, layers)
    pars = Flux.params(flow)
    prog = ProgressMeter.Progress(n_iter)
    opt = ADAM(lr)

    for _ in 1:n_iter
        batch_idx = sample(1:n_samples, n_batch; replace=false)
        batch = view(data, :, batch_idx)
        # `Ref(flow)` stops broadcasting from iterating over the distribution itself.
        loss, back = Zygote.pullback(pars) do
            -mean(logpdf.(Ref(flow), eachcol(batch)))
        end
        grad = back(one(loss))

        Flux.Optimise.update!(opt, pars, grad)
        ProgressMeter.next!(prog; showvalues=[(:loss, loss)])
    end
    return nothing
end
julia> main()
ERROR: MethodError: Cannot `convert` an object of type ChainRulesCore.ZeroTangent to an object of type ChainRulesCore.NoTangent
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
Stacktrace:
  [1] fill!(dest::Vector{ChainRulesCore.NoTangent}, x::ChainRulesCore.ZeroTangent)
    @ Base ./array.jl:333
  [2] _map_notzeropres!(f::typeof(Zygote.accum), fillvalue::ChainRulesCore.ZeroTangent, C::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, B::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:345
  [3] _noshapecheck_map(f::typeof(Zygote.accum), A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, Bs::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:166
  [4] _shapecheckbc(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::Vararg{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, N} where N)
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1026
  [5] _copy(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1016
  [6] copy
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1012 [inlined]
  [7] materialize
    @ ./broadcast.jl:883 [inlined]
  [8] accum(x::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ys::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:25
  [9] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [10] accum(x::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}, y::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [11] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [12] accum(x::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}, y::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [13] macro expansion
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
 [14] accum(x::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}}, y::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
 [15] getindex
    @ ./tuple.jl:29 [inlined]
 [16] gradindex
    @ ~/.julia/packages/Zygote/fDJjj/src/compiler/reverse.jl:12 [inlined]
 [17] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/interface.jl:102 [inlined]
  [18] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [19] macro expansion
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:0 [inlined]
 [20] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:222 [inlined]
  [21] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/Bijectors/du7oP/src/transformed_distribution.jl:108 [inlined]
  [23] (::typeof(∂(_logpdf)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/Distributions/1313k/src/multivariates.jl:201 [inlined]
  [25] (::typeof(∂(logpdf)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
 [26] #1073
    @ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:188 [inlined]
(PartialSMC) pkg> status
      Status `~/Projects/PartialSMC/Project.toml`
  [76274a88] Bijectors v0.9.9
  [31c24e10] Distributions v0.25.20
  [ced4e74d] DistributionsAD v0.6.31
  [634d3b9d] DrWatson v2.6.0
  [587475ba] Flux v0.12.7
  [5ab0869b] KernelDensity v0.6.3
  [872c559c] NNlib v0.7.29
  [90014a1f] PDMats v0.11.1
  [91a5bcdd] Plots v1.22.6
  [92933f4c] ProgressMeter v1.7.1
  [d330b81b] PyPlot v2.10.0
  [74087812] Random123 v1.4.2
  [e6cf234a] RandomNumbers v1.5.3
  [276daf66] SpecialFunctions v1.7.0
  [2913bbd2] StatsBase v0.33.11
  [4c63d2b9] StatsFuns v0.9.12
  [f3b207a7] StatsPlots v0.14.28
  [fce5fe82] Turing v0.18.0
  [e88e6eb3] Zygote v0.6.28
  [9a3f8284] Random
  [10745b16] Statistics

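This looks like an issue with the chain rules for sparse arrays.

In the trace, `Zygote.accum` fails while adding the gradients of the coupling layer's `mask` field (`A_1`, `A_2`, `A_3`). These gradients are `SparseMatrixCSC{ChainRulesCore.NoTangent}` matrices. Broadcasting `accum` over them tries to fill a `NoTangent` array with a `ZeroTangent`, which cannot be converted.

Could the sparse-array rules be producing tangents of the wrong type here?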

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions