diff --git a/src/lib/array.jl b/src/lib/array.jl
index 329676648..aed553f92 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -7,12 +7,23 @@
 Base.zero(xs::AbstractArray{Any}) = fill!(similar(xs), nothing)
 
+struct ∇getindex{T,S}
+  xs::T
+  i::S
+end
+
+function (g::∇getindex)(Δ)
+  Δ′ = zero(g.xs)
+  Δ′[g.i...] = Δ
+  (Δ′, map(_ -> nothing, g.i)...)
+end
+
+@adjoint function (g::∇getindex)(Δ)
+  g(Δ), Δ′′ -> (nothing, Δ′′[1][g.i...])
+end
+
 @adjoint function getindex(xs::Array, i...)
-  xs[i...], function (Δ)
-    Δ′ = zero(xs)
-    Δ′[i...] = Δ
-    (Δ′, map(_ -> nothing, i)...)
-  end
+  xs[i...], ∇getindex(xs, i)
 end
 
 @adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
diff --git a/src/lib/base.jl b/src/lib/base.jl
index fd582f5b7..1304d41ac 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -16,7 +16,7 @@
 end
   d[k], function (Δ)
     grad = grad_mut(__context__, d)
     grad[k] = accum(get(grad, k, nothing), Δ)
-    return (grad, nothing)
+    return (nobacksies(:getindex, grad), nothing)
   end
 end
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 0944407b1..260a83671 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -18,6 +18,19 @@ accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
   Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...)
 end
 
+"""
+    accum_sum(xs)
+
+`accum_sum` is to `sum` as `accum` is to `+` (i.e. `accum_sum` treats `nothing`
+as a strong zero).
+"""
+accum_sum(xs) = reduce(accum, xs)
+@adjoint accum_sum(xs) = accum_sum(xs), Δ -> (FillArray(Δ, size(xs)),)
+
+@adjoint function (T::Type{<:FillArray})(value, size)
+  T(value, size), Δ -> (nothing, accum_sum(Δ), nothing)
+end
+
 # Core functions
 
 @nograd Core.apply_type, Core.typeof, nfields, fieldtype,
@@ -41,18 +54,24 @@ end
 
 unwrap(x) = x
 
-@adjoint unwrap(x) = unwrap(x), Δ -> accum_param(__context__, x, Δ)
+@adjoint unwrap(x) = unwrap(x), Δ -> (accum_param(__context__, x, Δ); (Δ,))
+
+nobacksies(s, x) = x
+@adjoint nobacksies(s, x) = x, Δ -> error("Nested AD not defined for $s")
 
 # Tuples
 
 @adjoint tuple(xs...) = xs, identity
 
+tuple_at(Δ, i, N) = ntuple(j -> i == j ? Δ : nothing, Val(N))
+@adjoint tuple_at(Δ, i, N) = tuple_at(Δ, i, N), Δ′ -> (Δ′[i], nothing, nothing)
+
 @adjoint getindex(xs::NTuple{N,Any}, i::Integer) where N =
-  (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))
+  (xs[i], Δ -> (tuple_at(Δ, i, N), nothing))
 
 # Needed for iteration lowering
 @adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N =
-  (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))
+  (xs[i], Δ -> (tuple_at(Δ, i, N), nothing))
 
 @adjoint function Base.first(xs::Tuple)
   drest = map(_->nothing, tail(xs))
@@ -79,7 +98,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
   st = map(_empty, args)
   y, function (Δ)
     Δ = back(Δ)
-    (first(Δ), unapply(st, Base.tail(Δ))...)
+    (nobacksies(:apply, first(Δ)), unapply(st, Base.tail(Δ))...)
   end
 end
 
@@ -94,23 +113,39 @@ function deref!(x::Ref)
 end
 
 @generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...)
+@generated nt_nothing_type(::Type{T}) where {T} = Expr(:tuple, [:($f=nothing) for f in fieldnames(T)]...)
 
 @generated pair(::Val{k}, v) where k = :($k = v,)
 
+struct ∇getfield{T}
+  f::Symbol
+end
+
+function (g::∇getfield{T})(Δ) where {T}
+  ((;nt_nothing_type(T)..., pair(Val(g.f), Δ)...), nothing)
+end
+
+@adjoint function (g::∇getfield)(Δ)
+  g(Δ), Δ′′ -> (nothing, getfield(Δ′′[1], g.f))
+end
+
 # TODO make this inferrable
 # Right now constant prop is too fragile ...
 @adjoint function getfield(x, f::Symbol)
   val = getfield(x, f)
-  unwrap(val), function (Δ)
-    accum_param(__context__, val, Δ)
-    if isimmutable(x)
-      ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing)
-    else
+  back = if isimmutable(x)
+    g = ∇getfield{typeof(x)}(f)
+    isimmutable(val) ? g : Δ -> (accum_param(__context__, val, Δ); g(Δ))
+  else
+    # TODO: Nested AD for mutable structs
+    function (Δ)
+      accum_param(__context__, val, Δ)
       dx = getfield(grad_mut(__context__, x), f)
       dx[] = accum(dx[], Δ)
       return
     end
   end
+  unwrap(val), back
 end
 
 # ... so we have Zygote call this version where we can.
diff --git a/src/lib/real.jl b/src/lib/real.jl
index aee1274cb..ef3c9c453 100644
--- a/src/lib/real.jl
+++ b/src/lib/real.jl
@@ -19,6 +19,10 @@ end
 
 @adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), Δ -> (nothing, Δ)
 
+for T in Base.uniontypes(Core.BuiltinInts)
+  @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,)
+end
+
 @adjoint Base.:+(xs...) = +(xs...), Δ -> map(_ -> Δ, xs)
 
 @adjoint function sincos(x)
diff --git a/test/features.jl b/test/features.jl
index 34ebf01fd..4a9fe32a7 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -240,3 +240,29 @@ end
 if VERSION >= v"1.1"
   @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint
 end
+@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint
+
+# Basic nested AD
+f_nested(x) = x^4
+@test f_nested''(1.0) == 12.0
+
+# Nested AD for `sum`
+@test gradient([1.0, 2.0]) do x
+  gradient(x) do x
+    sin(sum(x))
+  end[1][1]
+end[1][1] == -sin(3.0)
+
+# Nested AD for `getindex`
+@test gradient([1.0, 2.0]) do x
+  gradient(x) do x
+    sin(x[1])
+  end[1][1]
+end[1][1] == -sin(1.0)
+
+# Third-order AD
+
+# Currently disabled pending improvements to Base and Zygote
+if false
+  @test sin'''(1.0) == -sin(1.0)
+end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index db9770631..e7fb872ff 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1,6 +1,7 @@
 using Zygote, NNlib, Test, Random, LinearAlgebra
 using Zygote: gradient
 using NNlib: conv
+import ForwardDiff
 import Random
 
 function ngradient(f, xs::AbstractArray...)
@@ -233,3 +234,27 @@
   @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[1]) == (1, 1)
   @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[2]) == (1, 10)
 end
+
+# Currently disabled pending improvements in Zygote and Base
+if false
+  @testset "third order AD (indexing)" begin
+    # Nested AD for getindex
+    grad_zygote = gradient([1.0, 2.0, 3.0]) do x
+      sum(gradient(x) do x
+        sum(gradient(x) do x
+          sum(x[1:2])^4
+        end[1])
+      end[1])
+    end[1]
+    # We compare to ForwardDiff, since the high-order derivative is not
+    # numerically stable under finite differencing.
+    grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
+      sum(ForwardDiff.gradient(x) do x
+        sum(ForwardDiff.gradient(x) do x
+          sum(x[1:2])^4
+        end)
+      end)
+    end
+    @test grad_zygote ≈ grad_forward ≈ [288.0, 288.0, 0.0]
+  end
+end
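
A quick usage sketch, separate from the patch itself and assuming only the Zygote API as modified above: replacing closure pullbacks with callable structs such as ∇getindex and ∇getfield (plus the nobacksies/tuple_at helpers) turns the backward pass into an ordinary function call that Zygote can differentiate again, which is what the nested-AD tests exercise.

    using Zygote

    # Second derivative via two nested reverse-mode passes.
    f(x) = x^4
    f''(1.0)                 # 12.0

    # Differentiating through a pullback that indexes into an array:
    # the inner gradient of sin(x[1]) is cos(x[1]); the outer pass
    # differentiates that, giving [-sin(1.0), 0.0].
    gradient([1.0, 2.0]) do x
      gradient(x) do x
        sin(x[1])
      end[1][1]
    end[1]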