diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index e4abf842..c5c434c0 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -30,7 +30,6 @@ jobs: - Mooncake - Tracker - ReverseDiff - - Zygote steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index 6a8a8ce6..4a8b2158 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,6 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" @@ -40,7 +39,6 @@ BijectorsMooncakeExt = "Mooncake" BijectorsReverseDiffExt = "ReverseDiff" BijectorsReverseDiffChainRulesExt = ["ChainRules", "ReverseDiff"] BijectorsTrackerExt = "Tracker" -BijectorsZygoteExt = "Zygote" [compat] ArgCheck = "1, 2" @@ -64,7 +62,6 @@ ReverseDiff = "1" Roots = "1.3.15, 2" Statistics = "1" Tracker = "0.2" -Zygote = "0.6.63, 0.7" julia = "1.10.8" [extras] @@ -75,4 +72,3 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/Project.toml b/docs/Project.toml index 9c81e994..e64b9b85 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,10 +2,8 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.27" Functors = "0.3" -StableRNGs = "1" -Zygote = "0.6" \ No newline at end of file +StableRNGs = "1" \ No newline at end of file diff --git a/docs/src/examples.md b/docs/src/examples.md index 67b1f2a3..368d4bed 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -112,7 +112,7 @@ y = rand(rng, td) Want to fit the flow? ```@repl normalizing-flows -using Zygote +using ForwardDiff # Construct the flow. b = PlanarLayer(2) @@ -145,7 +145,7 @@ f = NLLObjective(reconstruct, MvNormal(2, 1), xs); # Train using gradient descent. ε = 1e-3; for i in 1:100 - (∇s,) = Zygote.gradient(f, θs) + ∇s = ForwardDiff.gradient(θ -> f(θ), θs) θs = fmap(θs, ∇s) do θ, ∇ θ - ε .* ∇ end diff --git a/ext/BijectorsZygoteExt.jl b/ext/BijectorsZygoteExt.jl deleted file mode 100644 index 0db59007..00000000 --- a/ext/BijectorsZygoteExt.jl +++ /dev/null @@ -1,198 +0,0 @@ -module BijectorsZygoteExt - -using Zygote: Zygote, @adjoint, pullback -using Bijectors: - Elementwise, - SimplexBijector, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse, - maphcat, - IrrationalConstants, - Distributions, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - pd_logpdf_with_trans, - istraining, - eachcolmaphcat, - sumeachcol, - pd_link, - pd_from_lower, - lower_triangular, - upper_triangular, - getlogp - -using Bijectors.LinearAlgebra -using Bijectors.Distributions: LocationScale - -@adjoint istraining() = true, _ -> nothing - -@adjoint function eachcolmaphcat(f, x1, x2) - function g(f, x1, x2) - init = reshape(f(view(x1, :, 1), x2[1]), :, 1) - return reduce(hcat, [f(view(x1, :, i), x2[i]) for i in 2:size(x1, 2)]; init=init) - end - return pullback(g, f, x1, x2) -end -@adjoint function eachcolmaphcat(f, x) - function g(f, x) - init = reshape(f(view(x, :, 1)), :, 1) - return reduce(hcat, [f(view(x, :, i)) for i in 2:size(x, 2)]; init=init) - end - return pullback(g, f, x) -end -@adjoint function sumeachcol(f, x1, x2) - g(f, x1, x2) = sum([f(view(x1, :, i), x2[i]) for i in 1:size(x1, 2)]) - return pullback(g, f, x1, x2) -end - -@adjoint function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) - return -sum(log, x), Δ -> (nothing, -Δ ./ x) -end - -# AD implementations -@adjoint function _logabsdetjac_scale(a::Real, x::Real, ::Val{0}) - return _logabsdetjac_scale(a, x, Val(0)), Δ -> (inv(a) .* Δ, nothing, nothing) -end -@adjoint function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) - J = fill(inv.(a), length(x)) - return _logabsdetjac_scale(a, x, Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) -end -@adjoint function _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{0}) - J = fill(size(x, 1) / a, size(x, 2)) - return _logabsdetjac_scale(a, x, Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) -end -@adjoint function _logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) - # ∂ᵢ (∑ⱼ log|aⱼ|) = ∑ⱼ δᵢⱼ ∂ᵢ log|aⱼ| - # = ∂ᵢ log |aᵢ| - # = (1 / aᵢ) ∂ᵢ aᵢ - # = (1 / aᵢ) - J = inv.(a) - return _logabsdetjac_scale(a, x, Val(1)), Δ -> (J .* Δ, nothing, nothing) -end -@adjoint function _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) - Jᵀ = repeat(inv.(a), 1, size(x, 2)) - return _logabsdetjac_scale(a, x, Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) -end -## Positive definite matrices -@adjoint function replace_diag(::typeof(log), X) - f(i, j) = i == j ? log(X[i, j]) : X[i, j] - out = f.(1:size(X, 1), (1:size(X, 2))') - return out, ∇ -> begin - g(i, j) = i == j ? ∇[i, j] / X[i, j] : ∇[i, j] - (nothing, g.(1:size(X, 1), (1:size(X, 2))')) - end -end -@adjoint function replace_diag(::typeof(exp), X) - f(i, j) = ifelse(i == j, exp(X[i, j]), X[i, j]) - out = f.(1:size(X, 1), (1:size(X, 2))') - return out, ∇ -> begin - g(i, j) = ifelse(i == j, ∇[i, j] * exp(X[i, j]), ∇[i, j]) - (nothing, g.(1:size(X, 1), (1:size(X, 2))')) - end -end - -@adjoint function pd_logpdf_with_trans(d, X::AbstractMatrix{<:Real}, transform::Bool) - return pullback(pd_logpdf_with_trans_zygote, d, X, transform) -end -function pd_logpdf_with_trans_zygote(d, X::AbstractMatrix{<:Real}, transform::Bool) - T = eltype(X) - Xcf = cholesky(X; check=false) - if !issuccess(Xcf) - Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I; check=true) - end - lp = getlogp(d, Xcf, X) - if transform && isfinite(lp) - factors = Xcf.factors - n = size(d, 1) - k = n + 2 - @inbounds for i in diagind(factors) - k -= 1 - lp += k * log(factors[i]) - end - lp += n * oftype(lp, IrrationalConstants.logtwo) - end - return lp -end - -# Simplex adjoints - -@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector) - return _simplex_bijector(X, b), Δ -> (simplex_link_jacobian(X)' * Δ, nothing) -end -@adjoint function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) - return _simplex_inv_bijector(Y, b), Δ -> (simplex_invlink_jacobian(Y)' * Δ, nothing) -end - -@adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector) - return _simplex_bijector(X, b), - Δ -> begin - maphcat(eachcol(X), eachcol(Δ)) do c1, c2 - simplex_link_jacobian(c1)' * c2 - end, - nothing - end -end -@adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) - return _simplex_inv_bijector(Y, b), - Δ -> begin - maphcat(eachcol(Y), eachcol(Δ)) do c1, c2 - simplex_invlink_jacobian(c1)' * c2 - end, - nothing - end -end - -@adjoint function logabsdetjac(b::SimplexBijector, x::AbstractVector) - return logabsdetjac(b, x), Δ -> begin - (nothing, simplex_logabsdetjac_gradient(x) * Δ) - end -end - -# LocationScale fix -# TODO: Remove this. -@adjoint function Base.minimum(d::Distributions.LocationScale) - function _minimum(d) - m = minimum(d.ρ) - if isfinite(m) - return d.μ + d.σ * m - else - return m - end - end - return pullback(_minimum, d) -end -@adjoint function Base.maximum(d::LocationScale) - function _maximum(d) - m = maximum(d.ρ) - if isfinite(m) - return d.μ + d.σ * m - else - return m - end - end - return pullback(_maximum, d) -end -@adjoint function pd_from_lower(X::AbstractMatrix) - return LowerTriangular(X) * LowerTriangular(X)', - Δ -> begin - Xl = LowerTriangular(X) - return (LowerTriangular(Δ' * Xl + Δ * Xl),) - end -end -@adjoint function pd_link(X::AbstractMatrix{<:Real}) - return pullback(X) do X - Y = cholesky(X; check=true).L - return replace_diag(log, Y) - end -end - -end diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 088b405b..e6a235d9 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -1,6 +1,6 @@ struct PDBijector <: Bijector end -# This function has custom adjoints defined for Tracker, Zygote and ReverseDiff. +# This function has custom adjoints defined for Tracker and ReverseDiff. # I couldn't find a mutation-free implementation that maintains TrackedArrays in Tracker # and ReverseDiff, hence the need for custom adjoints. function replace_diag(f, X) diff --git a/src/chainrules.jl b/src/chainrules.jl index f15e1c22..a676f1cf 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -286,5 +286,6 @@ function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) end end -# Fixes Zygote's issues with `@debug` +# Fixes AD issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) + diff --git a/test/Project.toml b/test/Project.toml index 64657bb0..69849a10 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -25,7 +25,6 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "5" @@ -50,5 +49,4 @@ Mooncake = "0.4" ReverseDiff = "1.4.2" StableRNGs = "1" Tracker = "0.2.11" -Zygote = "0.6.63, 0.7" julia = "1.10" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index c36ce2dd..2abf7622 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -6,7 +6,6 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if !( b in ( :ForwardDiff, - :Zygote, :Mooncake, :ReverseDiff, :Enzyme, @@ -32,16 +31,6 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if AD == "All" || AD == "Zygote" - if :Zygote in broken - @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol = rtol atol = atol - else - ∇zygote = Zygote.gradient(f, x)[1] - @test (all(iszero, finitediff) && ∇zygote === nothing) || - isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) - end - end - if AD == "All" || AD == "ReverseDiff" if :ReverseDiff in broken @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol diff --git a/test/runtests.jl b/test/runtests.jl index c07404af..4a1e4c0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,6 @@ using LogExpFunctions using Mooncake using ReverseDiff using Tracker -using Zygote using Random, LinearAlgebra, Test