using Bijectors
using Flux
using ProgressMeter
using StatsBase
using StatsPlots
using Turing
using Zygote
"""
    main()

Run a small training loop for a `Bijectors.Coupling` flow on random 2-D data.

The data set is `randn(2, 10)`. A 2-D `MvNormal` base distribution is wrapped with the
coupling layer via `transformed`, and 3000 ADAM steps (learning rate `1e-3`) are taken on
the negative mean `logpdf` of mini-batches of 4 columns drawn without replacement. The
loss of every step is reported on a progress bar.
"""
function main()
    # Hyper-parameters of the experiment.
    num_steps = 3000
    step_size = 1e-3
    num_points = 10
    batch_size = 4

    # One observation per column.
    observations = randn(2, num_points)

    # Flow = base distribution pushed through a single coupling layer whose
    # parameter map builds a shift-after-scale bijector.
    prior = MvNormal(zeros(2), ones(2))
    coupling = Bijectors.Coupling(θ -> Bijectors.Shift(θ) ∘ Bijectors.Scale(θ), 2)
    flow = transformed(prior, coupling)

    trainable = Flux.params(flow)
    optimiser = ADAM(step_size)
    meter = ProgressMeter.Progress(num_steps)

    for _ in 1:num_steps
        # Draw a mini-batch of distinct columns; `view` avoids copying them.
        cols = sample(1:num_points, batch_size; replace=false)
        minibatch = view(observations, :, cols)

        # Negative mean log-density of the batch, together with its pullback.
        nll, pull = Zygote.pullback(trainable) do
            logdensities = logpdf.(Ref(flow), eachcol(minibatch))
            -mean(logdensities)
        end

        # Seed the pullback with 1 to obtain the gradients w.r.t. `trainable`.
        grads = pull(one(nll))
        Flux.Optimise.update!(optimiser, trainable, grads)
        ProgressMeter.next!(meter; showvalues=[(:loss, nll)])
    end
    return nothing
end
julia> main()
ERROR: MethodError: Cannot `convert` an object of type ChainRulesCore.ZeroTangent to an object of type ChainRulesCore.NoTangent
Closest candidates are:
convert(::Type{T}, ::T) where T at essentials.jl:205
Stacktrace:
[1] fill!(dest::Vector{ChainRulesCore.NoTangent}, x::ChainRulesCore.ZeroTangent)
@ Base ./array.jl:333
[2] _map_notzeropres!(f::typeof(Zygote.accum), fillvalue::ChainRulesCore.ZeroTangent, C::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, B::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
@ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:345
[3] _noshapecheck_map(f::typeof(Zygote.accum), A::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, Bs::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
@ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:166
[4] _shapecheckbc(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::Vararg{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, N} where N)
@ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1026
[5] _copy(::Function, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
@ SparseArrays.HigherOrderFns /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1016
[6] copy
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/SparseArrays/src/higherorderfns.jl:1012 [inlined]
[7] materialize
@ ./broadcast.jl:883 [inlined]
[8] accum(x::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, ys::SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:25
[9] macro expansion
@ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
[10] accum(x::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}, y::NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
[11] macro expansion
@ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
[12] accum(x::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}, y::NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
[13] macro expansion
@ ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27 [inlined]
[14] accum(x::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}}, y::NamedTuple{(:orig,), Tuple{NamedTuple{(:θ, :mask), Tuple{Nothing, NamedTuple{(:A_1, :A_2, :A_3), Tuple{SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}, SparseArrays.SparseMatrixCSC{ChainRulesCore.NoTangent, Int64}}}}}}})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/lib/lib.jl:27
[15] getindex
@ ./tuple.jl:29 [inlined]
[16] gradindex
@ ~/.julia/packages/Zygote/fDJjj/src/compiler/reverse.jl:12 [inlined]
[17] Pullback
@ ~/.julia/packages/Bijectors/du7oP/src/interface.jl:102 [inlined]
[18] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
[19] macro expansion
@ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:0 [inlined]
[20] Pullback
@ ~/.julia/packages/Bijectors/du7oP/src/bijectors/composed.jl:222 [inlined]
[21] (::typeof(∂(forward)))(Δ::NamedTuple{(:rv, :logabsdetjac), Tuple{Vector{Float64}, Float64}})
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
[22] Pullback
@ ~/.julia/packages/Bijectors/du7oP/src/transformed_distribution.jl:108 [inlined]
[23] (::typeof(∂(_logpdf)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
[24] Pullback
@ ~/.julia/packages/Distributions/1313k/src/multivariates.jl:201 [inlined]
[25] (::typeof(∂(logpdf)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/fDJjj/src/compiler/interface2.jl:0
[26] #1073
@ ~/.julia/packages/Zygote/fDJjj/src/lib/broadcast.jl:188 [inlined]
Hi,
`Coupling` currently has an issue with differentiation. Here's a reproducible example.
It seems to be an issue with the chain rule for sparse arrays?