diff --git a/Project.toml b/Project.toml
index cc28bbebe..c71dee4c4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.5.1"
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
@@ -14,8 +15,10 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
diff --git a/src/Zygote.jl b/src/Zygote.jl
index b10f5fbd1..6722ce880 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -11,10 +11,14 @@ using IRTools
 using MacroTools, Requires
 using MacroTools: @forward
 
-export Params, gradient, pullback, @code_grad
+export Params, gradient, pullback, pushforward, @code_grad
 
 include("tools/idset.jl")
 include("tools/buffer.jl")
+include("tools/builtins.jl")
+
+include("forward/Forward.jl")
+using .Forward
 
 include("compiler/reverse.jl")
 include("compiler/emit.jl")
diff --git a/src/forward/Forward.jl b/src/forward/Forward.jl
new file mode 100644
index 000000000..1478c59af
--- /dev/null
+++ b/src/forward/Forward.jl
@@ -0,0 +1,17 @@
+# Forward-mode differentiation: exports `pushforward`; rules are declared with `@tangent`.
+module Forward
+
+import ..Zygote
+import ..Zygote: __new__, __splatnew__
+
+export pushforward
+
+# Order matters: interface.jl defines `@tangent`, which the rule files after it use.
+include("compiler.jl")
+include("interface.jl")
+include("lib.jl")
+include("number.jl")
+include("array.jl")
+include("broadcast.jl")
+
+end
diff --git a/src/forward/array.jl b/src/forward/array.jl
new file mode 100644
index 000000000..0eb2edad2
--- /dev/null
+++ b/src/forward/array.jl
@@ -0,0 +1,29 @@
+# Tangent rules for common array operations.
+using LinearAlgebra
+
+zerolike(x::AbstractArray) = zerolike.(x)
+
+@tangent function (A::Type{<:Array})(::UndefInitializer, sz::Integer...)
+  x = A(UndefInitializer(), sz...)
+  x, (_...) -> zerolike(x)
+end
+
+@tangent length(A::AbstractArray) = length(A), _ -> 0
+@tangent size(A::AbstractArray, i::Integer) = size(A, i), (_, _) -> 0
+@tangent size(A::AbstractArray) = size(A), _ -> zerolike(size(A))
+
+@tangent Base.vect(xs...) = Base.vect(xs...), Base.vect
+
+@tangent fill(x, dims::Tuple) = fill(x, dims), (ẋ, _) -> fill(ẋ, dims)
+
+@tangent first(x) = first(x), first
+
+@tangent setindex!(x::AbstractArray, v, inds...) =
+  setindex!(x, v, inds...), (ẋ, v̇, _...) -> setindex!(ẋ, v̇, inds...)
+
+# Product rule; the tangent buffer Ċ is overwritten in place, as `mul!` does with C.
+@tangent mul!(C, A, B) = mul!(C, A, B), (Ċ, Ȧ, Ḃ) -> Ċ .= Ȧ*B .+ A*Ḃ
+
+@tangent A::AbstractArray * B::AbstractArray = A*B, (Ȧ, Ḃ) -> Ȧ*B .+ A*Ḃ
+
+@tangent sum(x; dims = :) = sum(x; dims = dims), ẋ -> sum(ẋ, dims = dims)
diff --git a/src/forward/broadcast.jl b/src/forward/broadcast.jl
new file mode 100644
index 000000000..da2705a63
--- /dev/null
+++ b/src/forward/broadcast.jl
@@ -0,0 +1,22 @@
+using Base.Broadcast: AbstractArrayStyle, broadcasted
+
+# A number, or an array whose elements are numbers.
+Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
+
+@tangent Broadcast.preprocess(dest, bc) =
+  Broadcast.preprocess(dest, bc), (ddest, dbc) -> dbc
+
+@tangent broadcasted(::typeof(identity), x::Numeric) = x, (_, ẋ) -> ẋ
+
+@tangent broadcasted(::typeof(+), xs::Numeric...) =
+  broadcast(+, xs...), (_, ẋs...) -> broadcast(+, ẋs...)
+
+@tangent function broadcasted(::typeof(tanh), x::Numeric)
+  y = tanh.(x)
+  y, (_, ẋ) -> ẋ .* (1 .- y.^2)
+end
+
+# Any broadcast that reaches the style-dispatched method is rejected with an explicit error.
+@tangent function broadcasted(::AbstractArrayStyle, f, args...)
+  error("Generic broadcast of $f not supported yet")
+end
diff --git a/src/forward/compiler.jl b/src/forward/compiler.jl
new file mode 100644
index 000000000..c171e74ae
--- /dev/null
+++ b/src/forward/compiler.jl
@@ -0,0 +1,98 @@
+using IRTools.All
+using IRTools: Pipe
+using Base: tail
+
+# Build the IR expression applying `tail` to `x` `n` times (`x` itself when n <= 0).
+ntail(x, n) = n <= 0 ? x : xcall(:tail, ntail(x, n-1))
+
+# Rewrite a `:new` / `:splatnew` statement into a call to `Zygote.__new__` /
+# `Zygote.__splatnew__` (storing it in `pr[v]`) so that `dual` treats it like any other
+# call. Other statements are returned untouched.
+function instrument!(pr, v, st)
+  ex = st.expr
+  if isexpr(ex, :new)
+    st = stmt(st, expr = xcall(Zygote, :__new__, ex.args...))
+    pr[v] = st
+  elseif isexpr(ex, :splatnew)
+    st = stmt(st, expr = xcall(Zygote, :__splatnew__, ex.args...))
+    pr[v] = st
+  end
+  return st
+end
+
+# Transform `ir` into its forward-mode counterpart: a tuple of input tangents becomes a
+# new argument at position 1, the tangent of every value is tracked in `Δs`, and each
+# return yields `(value, tangent)` instead of `value`.
+function dual(ir)
+  args = copy(arguments(ir))
+  dx = argument!(ir, at = 1)
+  Δs = Dict()
+  # Each argument of a non-entry block gets a companion argument for its tangent.
+  for bl in blocks(ir)[2:end], arg in copy(arguments(bl))
+    Δs[arg] = argument!(bl, insert = false)
+  end
+  pr = Pipe(ir)
+  # Tangent of `x`: looked up in `Δs` for IR variables, a `zerolike(x)` call otherwise.
+  # The two-argument forms place that call relative to statement `v` via `insert!`.
+  partial(x::Variable) = Δs[x]
+  partial(x) = push!(pr, xcall(Forward, :zerolike, x))
+  partial(v, x::Variable) = Δs[x]
+  partial(v, x) = insert!(pr, v, xcall(Forward, :zerolike, x))
+  # Unpack the incoming tangent tuple; a trailing vararg takes the remaining tail.
+  for (i, x) in enumerate(args)
+    if i == length(args) && ir.meta.method.isva
+      Δs[x] = push!(pr, ntail(dx, i-1))
+    else
+      Δs[x] = push!(pr, xcall(:getindex, dx, i))
+    end
+  end
+  # Returns yield `(value, tangent)`; other branches pass tangents after their values.
+  branches(pr) do br
+    # Distinct name: assigning to `args` here would overwrite the outer local.
+    brargs = arguments(br)
+    if isreturn(br)
+      brargs[1] = push!(pr, xcall(:tuple, brargs[1], partial(brargs[1])))
+    else
+      for arg in copy(brargs)
+        push!(brargs, partial(arg))
+      end
+    end
+    br
+  end
+  for (v, st) in pr
+    st = instrument!(pr, v, st)
+    if isexpr(st.expr, :meta, :inbounds, :loopinfo)
+      Δs[v] = nothing
+    elseif isexpr(st.expr, :boundscheck) ||
+           (isexpr(st.expr, :call) && st.expr.args[1] == GlobalRef(Base, :not_int)) ||
+           (isexpr(st.expr, :call) && st.expr.args[1] == GlobalRef(Core, :(===))) ||
+           (isexpr(st.expr, :call) && st.expr.args[1] == GlobalRef(Main, :(===)))
+      Δs[v] = false
+    elseif isexpr(st.expr, :call)
+      # Replace `f(args...)` by `_pushforward(dargs, f, args...)` and split the result.
+      dargs = insert!(pr, v, xcall(:tuple, partial.((v,), st.expr.args)...))
+      result = insert!(pr, v, stmt(st, expr = xcall(Forward, :_pushforward, dargs, st.expr.args...)))
+      pr[v] = xcall(:getindex, result, 1)
+      Δs[v] = push!(pr, xcall(:getindex, result, 2))
+    elseif !isexpr(st.expr)
+      Δs[v] = push!(pr, xcall(Forward, :zerolike, v))
+    else
+      error("Unsupported $(st.expr.head) expression")
+    end
+  end
+  ir = finish(pr)
+  return ir
+end
+
+# Generated entry point: obtain the IR for the call `x...`, pass it through
+# `Zygote.instrument`, and return its `dual`. If no IR is available the generated code
+# raises an error.
+# NOTE(review): `args` in the error expression is presumably the dynamo's runtime
+# argument tuple (making `args[2]` the function) — confirm against IRTools.
+@dynamo function _pushforward(_, x...)
+  ir = IR(x...)
+  ir === nothing && return :(error("non-differentiable function $(args[2])"))
+  ir = Zygote.instrument(ir)
+  ir.meta.code.inlineable = true
+  return dual(ir)
+end
diff --git a/src/forward/interface.jl b/src/forward/interface.jl
new file mode 100644
index 000000000..3286426a0
--- /dev/null
+++ b/src/forward/interface.jl
@@ -0,0 +1,64 @@
+using MacroTools: @capture, @q, shortdef
+using ZygoteRules: named, typeless, isvararg
+using Base: tail
+
+# `drop(x, n)` builds the expression applying `tail` to `x` `n` times;
+# `drop(n)` is the curried form.
+drop(x, n) = n == 0 ? x : :(tail($(drop(x, n-1))))
+drop(n) = x -> drop(x, n)
+
+# TODO: move to ZygoteRules
+function tangent end
+
+# Expand a rule definition `f(args...) = y, forw` into
+#   * a `tangent` method holding the user's body, and
+#   * `_pushforward` methods (plain call and keyword call) that invoke it and apply
+#     `forw` to the incoming tangents.
+# The tangent of the callee itself is dropped before calling `forw`, unless the rule is
+# written as `(f::T)(args...)`, in which case `forw` receives it first.
+function gradm(ex)
+  @capture(shortdef(ex), (name_(args__) = body_) |
+                         (name_(args__) where {Ts__} = body_)) || error("Need a function definition")
+  kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing
+  isclosure = isexpr(name, :(::)) && length(name.args) > 1
+  f, T = isexpr(name, :(::)) ?
+    (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) :
+    (esc(gensym()), :(Core.Typeof($(esc(name)))))
+  kT = :(Core.kwftype($T))
+  Ts === nothing && (Ts = [])
+  args = named.(args)
+  argnames = Any[typeless(arg) for arg in args]
+  !isempty(args) && isvararg(args[end]) && (argnames[end] = :($(argnames[end])...,))
+  args = esc.(args)
+  argnames = esc.(argnames)
+  Ts = esc.(Ts)
+  fargs = kw === nothing ? [:($f::$T), args...] : [kw, :($f::$T), args...]
+  dropg  = isclosure ? identity : drop(1)
+  # A keyword call arrives as `(kwfunc, kw, f, args...)`: two more tangents to skip.
+  dropkw = isclosure ? drop(2) : drop(3)
+  adj = @q @inline Zygote.Forward.tangent($(fargs...)) where $(Ts...) = $(esc(body))
+  quote
+    $adj
+    @inline function Zygote.Forward._pushforward(partials, $f::$T, $(args...)) where $(Ts...)
+      y, forw = tangent($f, $(argnames...))
+      return y, forw($(dropg(:partials))...)
+    end
+    @inline function Zygote.Forward._pushforward(dargs, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
+      # Forward the caller's keyword arguments to the rule, and slice the tangents of
+      # this method's own `dargs` argument.
+      y, forw = tangent($f, $(argnames...); kw...)
+      return y, forw($(dropkw(:dargs))...)
+    end
+    nothing
+  end
+end
+
+# Define a forward-mode rule. The body returns `(value, forw)`, where `forw` maps the
+# argument tangents to the tangent of `value`.
+macro tangent(ex)
+  gradm(ex)
+end
+
+# `pushforward(f, x...)` returns a function that takes one tangent per `x` and returns
+# the tangent of `f(x...)`; `f` itself is seeded with `zerolike(f)`.
+pushforward(f, x...) = (ẋ...) -> _pushforward((zerolike(f), ẋ...), f, x...)[2]
diff --git a/src/forward/lib.jl b/src/forward/lib.jl
new file mode 100644
index 000000000..ecfc4c511
--- /dev/null
+++ b/src/forward/lib.jl
@@ -0,0 +1,91 @@
+# `zerolike(x)` produces the zero tangent used for `x`.
+zerolike(x::Number) = zero(x)
+zerolike(x::Tuple) = zerolike.(x)
+
+# Generic fallback: `nothing` for types without fields, otherwise a NamedTuple holding
+# `zerolike` of every field.
+@generated function zerolike(x::T) where T
+  length(fieldnames(T)) == 0 ? nothing :
+  :(NamedTuple{$(fieldnames(T))}(($(map(f -> :(zerolike(x.$f)), fieldnames(T))...),)))
+end
+
+# TODO figure out why this made a test fail
+zerolike(x::Union{Module,Type}) = nothing
+
+# TODO: `@nograd` and `@linear`
+
+@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
+@tangent one(x::Number) = one(x), _ -> zero(x)
+@tangent one(T::Type) = one(T), _ -> zero(T)
+@tangent Core.Typeof(x) = Core.Typeof(x), _ -> nothing
+@tangent typeof(x) = typeof(x), _ -> nothing
+@tangent Core.apply_type(args...) = Core.apply_type(args...), (_...) -> nothing
+@tangent fieldtype(args...) = fieldtype(args...), (_...) -> nothing
+@tangent isa(a, b) = isa(a, b), (_, _) -> false
+@tangent repr(x) = repr(x), _ -> nothing
+@tangent println(x...) = println(x...), (_...) -> nothing
+@tangent typeassert(x, T) = typeassert(x, T), (ẋ, _) -> ẋ
+@tangent fieldnames(T) = fieldnames(T), _ -> zerolike(fieldnames(T))
+@tangent eltype(x) = eltype(x), ẋ -> zerolike(eltype(ẋ))
+
+@tangent fieldcount(T) = fieldcount(T), _ -> zerolike(fieldcount(T))
+
+@tangent tuple(t...) = t, (ṫ...) -> ṫ
+@tangent tail(t) = tail(t), tail
+
+@tangent setfield!(t, i, x) = setfield!(t, i, x), (ṫ, _, ẋ) -> setfield!(ṫ, i, ẋ)
+@tangent getindex(t, i) = getindex(t, i), (ṫ, _) -> getindex(ṫ, i)
+@tangent isdefined(t, i) = isdefined(t, i), (_, _) -> false
+
+# TODO should be using a context for this
+zerolike(x::Core.Box) = isdefined(x, :contents) ? Core.Box(zerolike(x.contents)) : Core.Box()
+@tangent Core.Box() = Core.Box(), () -> Core.Box()
+# The tangent of a new box wraps the tangent `ẋ` of its contents.
+@tangent Core.Box(x) = Core.Box(x), ẋ -> Core.Box(ẋ)
+# TODO: this is too generic (e.g. broadcast).
+# @tangent Base.copy(x) = copy(@show x), ẋ -> copy(ẋ)
+
+@tangent Core.Compiler.return_type(args...) =
+  Core.Compiler.return_type(args...), (_...) -> nothing
+
+@tangent __new__(T, s...) =
+  __new__(T, s...), (_, ṡ...) -> NamedTuple{fieldnames(T)}(ṡ)
+
+@tangent __splatnew__(T, s) =
+  __splatnew__(T, s), (_, ṡ) -> NamedTuple{fieldnames(T)}(ṡ)
+
+# Splatted call `f(args...)`: the tangents of the splatted collections are flattened
+# the same way `_apply` flattens the arguments themselves.
+function _pushforward(dargs, ::typeof(Core._apply), f, args...)
+  dargs = tail(dargs) # drop self gradient
+  df, dargs = first(dargs), tail(dargs)
+  dargs = Core._apply(tuple, dargs...)
+  Core._apply(_pushforward, ((df, dargs...), f), args...)
+end
+
+# `_apply_iterate(iterate, f, args...)`: skip the tangent of `iterate` and reuse the
+# `_apply` rule, which discards the leading (self) slot anyway.
+if VERSION >= v"1.4.0-DEV.304"
+  _pushforward(dargs, ::typeof(Core._apply_iterate), ::typeof(iterate), f, args...) =
+    _pushforward((first(dargs), tail(tail(dargs))...), Core._apply, f, args...)
+end
+
+using ..Zygote: literal_getproperty, literal_getindex
+
+# Route `getproperty` through `literal_getproperty`, lifting the field name to a `Val`.
+_pushforward(dargs, ::typeof(getproperty), x, f) =
+  _pushforward(dargs, literal_getproperty, x, Val(f))
+
+@tangent function literal_getproperty(t, ::Val{i}) where i
+  y = getproperty(t, i)
+  forw(ṫ, _) = getproperty(ṫ, i)
+  # A `nothing` tangent for the whole object gives a zero tangent for the field.
+  forw(ṫ::Nothing, _) = zerolike(y)
+  return y, forw
+end
+
+@tangent literal_getindex(t, ::Val{i}) where i =
+  getindex(t, i), (ṫ, _) -> getindex(ṫ, i)
+
+@tangent getfield(t::Tuple, i::Integer) =
+  getfield(t, i), (ṫ, _) -> getfield(ṫ, i)
diff --git a/src/forward/number.jl b/src/forward/number.jl
new file mode 100644
index 000000000..db88af656
--- /dev/null
+++ b/src/forward/number.jl
@@ -0,0 +1,54 @@
+using DiffRules, SpecialFunctions, NaNMath
+using Base.FastMath: fast_op, make_fastmath
+
+# TODO use CSE here
+
+# Unary DiffRules: the tangent is `ẋ` times the rule's derivative expression.
+for (M, f, arity) in DiffRules.diffrules()
+  arity == 1 || continue
+  dx = DiffRules.diffrule(M, f, :x)
+  @eval begin
+    @tangent $M.$f(x::Number) = $M.$f(x), ẋ -> ẋ * $dx
+  end
+end
+
+# Binary DiffRules: sum of the contributions from both arguments.
+for (M, f, arity) in DiffRules.diffrules()
+  arity == 2 || continue
+  da, db = DiffRules.diffrule(M, f, :a, :b)
+  @eval begin
+    @tangent $M.$f(a::Number, b::Number) = $M.$f(a, b), (ȧ, ḃ) -> ȧ*$da + ḃ*$db
+  end
+end
+
+# Some specific overrides
+# The DiffRules definitions are suboptimal due to repeated work in the tangent
+
+@tangent function tanh(x)
+  y = tanh(x)
+  y, ẋ -> ẋ * (1 - y^2)
+end
+
+@tangent function exp(x)
+  y = exp(x)
+  y, ẋ -> ẋ * y
+end
+
+# Comparisons and membership tests return Bools; their tangent is `false`.
+for f in [>, <, ==, ===, !=, in]
+  @eval @tangent $f(a, b) = $f(a, b), (_, _) -> false
+end
+
+@tangent convert(T::Type{<:Real}, x::Real) = convert(T, x), (_, ẋ) -> convert(T, ẋ)
+
+@tangent function Colon()(xs...)
+  c = Colon()(xs...)
+  c, (_...) -> zerolike(c)
+end
+
+# Ranges use the generic `zerolike(::Any)` method rather than the AbstractArray one.
+zerolike(x::AbstractRange) =
+  invoke(zerolike, Tuple{Any}, x)
+
+# Complex method for DiffRules' `_abs_deriv` helper.
+DiffRules._abs_deriv(x::Complex) = x/abs(x)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index c1712cf32..3447cdd20 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -91,10 +91,6 @@ using Base: tail
 
 @adjoint tuple(xs...) = xs, identity
 
-literal_getindex(x, ::Val{i}) where i = getindex(x, i)
-literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i)
-literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state)
-
 @adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
   val = xs[i]
   function back(Δ)
@@ -248,20 +244,6 @@ end
   end
 end
 
-@generated function __new__(T, args...)
-  quote
-    Base.@_inline_meta
-    $(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...))
-  end
-end
-
-@generated function __splatnew__(T, args)
-  quote
-    Base.@_inline_meta
-    $(Expr(:splatnew, :T, :args))
-  end
-end
-
 struct Jnew{T,G,splat}
   g::G
 end
diff --git a/src/tools/builtins.jl b/src/tools/builtins.jl
new file mode 100644
index 000000000..6d0daf57c
--- /dev/null
+++ b/src/tools/builtins.jl
@@ -0,0 +1,22 @@
+# Low-level helpers, defined before `Zygote.Forward` is loaded because it imports them.
+
+# Construct a `T` from its field values through a raw `:new` expression.
+@generated function __new__(T, args...)
+  quote
+    Base.@_inline_meta
+    $(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...))
+  end
+end
+
+# Same, with the field values supplied as a single collection (`:splatnew`).
+@generated function __splatnew__(T, args)
+  quote
+    Base.@_inline_meta
+    $(Expr(:splatnew, :T, :args))
+  end
+end
+
+# Function forms of indexing and `indexed_iterate` with the index lifted to a `Val`.
+literal_getindex(x, ::Val{i}) where i = getindex(x, i)
+literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i)
+literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state)
diff --git a/test/forward/forward.jl b/test/forward/forward.jl
new file mode 100644
index 000000000..97c1efb9e
--- /dev/null
+++ b/test/forward/forward.jl
@@ -0,0 +1,50 @@
+using Zygote, Test
+using NNlib: relu
+
+D(f, x) = pushforward(f, x)(1)
+
+@test D(x -> sin(cos(x)), 0.5) ≈ -cos(cos(0.5))*sin(0.5)
+
+@test D(x -> D(cos, x), 0.5) ≈ -cos(0.5)
+
+@test D(x -> x*D(y -> x*y, 1), 4) == 8
+
+function pow(x, n)
+  r = 1
+  while n > 0
+    n -= 1
+    r *= x
+  end
+  return r
+end
+
+@test D(x -> pow(x, 3), 2) == 12
+
+@test D(1) do x
+  f(y) = x = x*y
+  D(f, 1)
+  D(f, 1)
+end == 1
+
+@test D(x -> D(y -> x = y, x)*x, 1) == 1
+
+@test D(1) do x
+  D(2) do y
+    D(3) do z
+      x = z * y
+    end
+  end
+  x
+end == 0
+
+@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
+
+using LinearAlgebra
+
+@test D(3) do x
+  A = zeros(5, 5)
+  B = zeros(5, 5)
+  A[1, 1] = x
+  mul!(B, A, A)
+  sum(B)
+end == 6
diff --git a/test/runtests.jl b/test/runtests.jl
index 409f6c54e..5e7ca4e62 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,16 +2,16 @@ using Zygote, Test
 using Zygote: gradient
 using CUDA: has_cuda
 
-@testset "Interface" begin
+@testset "Interface" begin
   include("interface.jl")
 end
 
 
-@testset "Tools" begin
+@testset "Tools" begin
   include("tools.jl")
 end
 
-@testset "lib/number" begin
+@testset "lib/number" begin
   include("lib/number.jl")
 end
 
@@ -19,6 +19,10 @@ end
   include("features.jl")
 end
 
+@testset "Forward" begin
+  include("forward/forward.jl")
+end
+
 @testset "Data Structures" begin
   include("structures.jl")
 end