From fd8dc6b5aee8c1aa4ed821150333da0c0780b5aa Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 22 Feb 2019 16:15:09 -0500 Subject: [PATCH 1/5] Some limited definitions for nested AD This adds a bunch of definitions to make nested AD work and adds tests for second-order AD. Unfortunately by the time we get to third-order AD, the types get so large that base decides to go on vacation while it thinks about whether or not it might be willing to compile a function with a type of such complexity. Additionally, Zygote introduces some unnecessary stacks, which then prevent higher order AD. I plan to work on both of those issues, but in the meantime, here are the changes to Zygote required to make this work. --- src/lib/array.jl | 21 ++++++++++++++----- src/lib/base.jl | 2 +- src/lib/lib.jl | 53 +++++++++++++++++++++++++++++++++++++++-------- src/lib/real.jl | 4 ++++ test/features.jl | 26 +++++++++++++++++++++++ test/gradcheck.jl | 25 ++++++++++++++++++++++ 6 files changed, 116 insertions(+), 15 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 329676648..aed553f92 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -7,12 +7,23 @@ Base.zero(xs::AbstractArray{Any}) = fill!(similar(xs), nothing) +struct ∇getindex{T,S} + xs::T + i::S +end + +function (g::∇getindex)(Δ) + Δ′ = zero(g.xs) + Δ′[g.i...] = Δ + (Δ′, map(_ -> nothing, g.i)...) +end + +@adjoint function (g::∇getindex)(Δ) + g(Δ), Δ′′->(nothing, Δ′′[1][g.i...]) +end + @adjoint function getindex(xs::Array, i...) - xs[i...], function (Δ) - Δ′ = zero(xs) - Δ′[i...] = Δ - (Δ′, map(_ -> nothing, i)...) - end + xs[i...], ∇getindex(xs, i) end @adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...), diff --git a/src/lib/base.jl b/src/lib/base.jl index fd582f5b7..1304d41ac 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -16,7 +16,7 @@ end d[k], function (Δ) grad = grad_mut(__context__, d) grad[k] = accum(get(grad, k, nothing), Δ) - return (grad, nothing) + return (nobacksies(:getindex, grad), nothing) end end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 0944407b1..260a83671 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -18,6 +18,19 @@ accum(x::AbstractArray, y::AbstractArray) = accum.(x, y) Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...) end +""" + accum_sum(Δ) + +`accum_sum` is to `sum` as `accum` is to `+` (i.e. accum_sum treats `nothing` +as a strong zero). +""" +accum_sum(Δ) = reduce(accum, Δ) +@adjoint accum_sum(Δ) = accum_sum(Δ), Δ -> (FillArray(Δ, size(xs)),) + +@adjoint function (T::Type{<:FillArray})(value, size) + T(value, size), Δ->(nothing, accum_sum(Δ), nothing) +end + # Core functions @nograd Core.apply_type, Core.typeof, nfields, fieldtype, @@ -41,18 +54,24 @@ end unwrap(x) = x -@adjoint unwrap(x) = unwrap(x), Δ -> accum_param(__context__, x, Δ) +@adjoint unwrap(x) = unwrap(x), Δ ->(accum_param(__context__, x, Δ); (Δ,)) + +nobacksies(s, x) = x +@adjoint nobacksies(s, x) = x, Δ->error("Nested AD not defined for $s") # Tuples @adjoint tuple(xs...) = xs, identity +tuple_at(Δ, i, N) = ntuple(j -> i == j ? Δ : nothing, Val(N)) +@adjoint tuple_at(Δ, i, N) = tuple_at(Δ, i, N), Δ′->(Δ′[i], nothing, nothing) + @adjoint getindex(xs::NTuple{N,Any}, i::Integer) where N = - (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) + (xs[i], Δ -> (tuple_at(Δ, i, N), nothing)) # Needed for iteration lowering @adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = - (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) + (xs[i], Δ -> (tuple_at(Δ, i, N), nothing)) @adjoint function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) @@ -79,7 +98,7 @@ unapply(t, xs) = _unapply(t, xs)[1] st = map(_empty, args) y, function (Δ) Δ = back(Δ) - (first(Δ), unapply(st, Base.tail(Δ))...) + (nobacksies(:apply, first(Δ)), unapply(st, Base.tail(Δ))...) end end @@ -94,23 +113,39 @@ function deref!(x::Ref) end @generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...) +@generated nt_nothing_type(::Type{T}) where {T} = Expr(:tuple, [:($f=nothing) for f in fieldnames(T)]...) @generated pair(::Val{k}, v) where k = :($k = v,) +struct ∇getfield{T} + f::Symbol +end + +function (g::∇getfield{T})(Δ) where {T} + ((;nt_nothing_type(T)...,pair(Val(g.f), Δ)...), nothing) +end + +@adjoint function (g::∇getfield)(Δ) + g(Δ), Δ′′->(nothing, getfield(Δ′′[1], g.f)) +end + # TODO make this inferrable # Right now constant prop is too fragile ... @adjoint function getfield(x, f::Symbol) val = getfield(x, f) - unwrap(val), function (Δ) - accum_param(__context__, val, Δ) - if isimmutable(x) - ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing) - else + back = if isimmutable(x) + g = ∇getfield{typeof(x)}(f) + isimmutable(val) ? g : Δ->(accum_param(__context__, val, Δ); g(Δ)) + else + # TODO: Nested AD for mutable structs + function (Δ) + accum_param(__context__, val, Δ) dx = getfield(grad_mut(__context__, x), f) dx[] = accum(dx[], Δ) return end end + unwrap(val), back end # ... so we have Zygote call this version where we can. diff --git a/src/lib/real.jl b/src/lib/real.jl index aee1274cb..ef3c9c453 100644 --- a/src/lib/real.jl +++ b/src/lib/real.jl @@ -19,6 +19,10 @@ end @adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), Δ -> (nothing, Δ) +for T in Base.uniontypes(Core.BuiltinInts) + @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) +end + @adjoint Base.:+(xs...) = +(xs...), Δ -> map(_ -> Δ, xs) @adjoint function sincos(x) diff --git a/test/features.jl b/test/features.jl index 34ebf01fd..4a9fe32a7 100644 --- a/test/features.jl +++ b/test/features.jl @@ -240,3 +240,29 @@ end if VERSION >= v"1.1" @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint end +@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint + +# Basic nested +f_nested(x) = x^4 +@test f_nested''(1.0) = 12.0 + +# Nested AD for `sum` +@test gradient([1.0, 2.0]) do x + gradient(x) do x + sin(sum(x)) + end[1][1] +end == -sin(3.0) + +# Nested AD for getindex +@test gradient([1.0, 2.0]) do x + gradient(x) do x + sin(x[1]) + end[1][1] +end == -sin(1.0) + +# Third-order AD + +# Currently disabled pending improvements to Base and Zygote +if false + @test sin'''(1.0) == -sin(1.0) +end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index db9770631..e7fb872ff 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1,6 +1,7 @@ using Zygote, NNlib, Test, Random, LinearAlgebra using Zygote: gradient using NNlib: conv +import ForwardDiff import Random function ngradient(f, xs::AbstractArray...) @@ -233,3 +234,27 @@ end @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[1]) == (1, 1) @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[2]) == (1, 10) end + +# Currently disabled pending improvements in Zygote and Base +if false + @testset "third order AD (indexing)" begin + # Nested AD for getindex + grad_tracker = gradient([1.0, 2.0, 3.0]) do x + sum(gradient(x) do x + sum(gradient(x) do x + sum(x[1:2])^4 + end[1]) + end[1]) + end[1] + # We compare to ForwardDiff, since the high order derivative is not + # numerically stable under finite differencing. + grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x + sum(ForwardDiff.gradient(x) do x + sum(ForwardDiff.gradient(x) do x + sum(x[1:2])^4 + end) + end) + end + @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0] + end +end From e78e418d33de42d76ce5925569c402287a671e3a Mon Sep 17 00:00:00 2001 From: axsk Date: Mon, 12 Apr 2021 12:22:03 +0200 Subject: [PATCH 2/5] =?UTF-8?q?fix=20adjoint=20for=20=E2=88=87getindex?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/lib/array.jl | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 3bfad52f6..61aded918 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -32,16 +32,23 @@ end @adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds) -∇getindex(x::AbstractArray, inds) = dy -> begin - if inds isa NTuple{<:Any, Integer} - dx = _zero(x, typeof(dy)) - dx[inds...] = dy - else - dx = _zero(x, eltype(dy)) - dxv = view(dx, inds...) - dxv .= accum.(dxv, _droplike(dy, dxv)) - end - return (dx, map(_->nothing, inds)...) +∇getindex(x::AbstractArray, inds) = dy -> (_zerosetindex(x, inds, dy), map(_->nothing, inds)...) + +function _zerosetindex(x, inds::NTuple{<:Any, Integer}, dy) + dx = _zero(x, typeof(dy)) + dx[inds...] = dy + dx +end + +function _zerosetindex(x, inds, dy) + dx = _zero(x, eltype(dy)) + dxv = view(dx, inds...) + dxv .= accum.(dxv, _droplike(dy, dxv)) + dx +end + +@adjoint function _zerosetindex(x, inds, dy) + _zerosetindex(x, inds, dy), ddx -> (nothing, nothing, ddx[inds...]) end _zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs))) From 2a23756d1fe2c8ce13028056eae81199e0ce2ccd Mon Sep 17 00:00:00 2001 From: axsk Date: Mon, 12 Apr 2021 13:19:35 +0200 Subject: [PATCH 3/5] fix tests --- test/features.jl | 13 +++---------- test/utils.jl | 4 ++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/test/features.jl b/test/features.jl index ff674d863..2258fb0a1 100644 --- a/test/features.jl +++ b/test/features.jl @@ -485,25 +485,18 @@ end # Basic nested f_nested(x) = x^4 -@test f_nested''(1.0) = 12.0 +@test f_nested''(1.0) == 12.0 # Nested AD for `sum` @test gradient([1.0, 2.0]) do x gradient(x) do x sin(sum(x)) end[1][1] -end == -sin(3.0) +end[1][1] == -sin(3.0) # Nested AD for getindex @test gradient([1.0, 2.0]) do x gradient(x) do x sin(x[1]) end[1][1] -end == -sin(1.0) - -# Third-order AD - -# Currently disabled pending improvements to Base and Zygote -if false - @test sin'''(1.0) == -sin(1.0) -end +end[1][1] == -sin(1.0) diff --git a/test/utils.jl b/test/utils.jl index d09fc2dc2..9d7cf32d8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -7,8 +7,8 @@ using Zygote: hessian_dual, hessian_reverse @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version else - @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex - @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] end @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1 From c665e6e4e5568476fc9bb0593263073b030d5fe5 Mon Sep 17 00:00:00 2001 From: axsk Date: Mon, 12 Apr 2021 13:28:51 +0200 Subject: [PATCH 4/5] remove duplicate test --- test/features.jl | 1 - test/gradcheck.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/features.jl b/test/features.jl index 2258fb0a1..f60bfe2f4 100644 --- a/test/features.jl +++ b/test/features.jl @@ -481,7 +481,6 @@ end Zygote.gradient(loss_adjoint,[1.0]) @test x[1] == x[2] end -@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint # Basic nested f_nested(x) = x^4 diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 30699f8a2..c7147712e 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1696,4 +1696,4 @@ if false end @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0] end -end \ No newline at end of file +end From 0079c661c951a6235ed14325f3a456806cb0c0f6 Mon Sep 17 00:00:00 2001 From: axsk Date: Mon, 12 Apr 2021 16:09:00 +0200 Subject: [PATCH 5/5] cleanup duplicate hess tests --- test/utils.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 9d7cf32d8..38eac121c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,13 +3,8 @@ using Zygote: hessian_dual, hessian_reverse @testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] - if hess == hessian_dual - @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] - @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version - else - @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex - @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] - end + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1