-
-
Notifications
You must be signed in to change notification settings - Fork 218
Forward Mode #503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Forward Mode #503
Changes from 8 commits
81a5efc
9d5e571
d4f34ba
a3b404e
bc1cc83
a02e888
68a3a5e
0e8badc
9c7c2f0
256e9cc
84c5d30
10121c1
365d2f1
72543bb
5cf6a2c
6b73a2c
ff742be
8859c01
2a02d4c
05f4043
08d89b7
2178efe
5850d96
e25ac87
e5b31ba
bf8bf7b
1413b83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| module Forward | ||
|
|
||
| import ..Zygote | ||
| import ..Zygote: __new__, __splatnew__ | ||
|
|
||
| export pushforward | ||
|
|
||
| include("compiler.jl") | ||
| include("interface.jl") | ||
| include("lib.jl") | ||
| include("number.jl") | ||
| include("array.jl") | ||
|
|
||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| using LinearAlgebra | ||
|
|
||
| zerolike(x::AbstractArray) = zerolike.(x) | ||
|
|
||
| @tangent function (A::Type{<:Array})(::UndefInitializer, sz::Integer...) | ||
| x = A(UndefInitializer(), sz...) | ||
| x, (_...) -> zerolike(x) | ||
| end | ||
|
|
||
| @tangent length(A::AbstractArray) = length(A), _ -> 0 | ||
| @tangent size(A::AbstractArray, i::Integer) = size(A, i), (_, _) -> 0 | ||
| @tangent size(A::AbstractArray) = size(A), _ -> zerolike(size(A)) | ||
|
|
||
| @tangent Base.vect(xs...) = Base.vect(xs...), Base.vect | ||
|
|
||
| @tangent first(x) = first(x), ẋ -> first(ẋ) | ||
|
|
||
| @tangent setindex!(x::AbstractArray, v, inds...) = | ||
| setindex!(x, v, inds...), (ẋ, v̇, _...) -> setindex!(ẋ, v̇, inds...) | ||
|
|
||
| @tangent mul!(C, A, B) = mul!(C, A, B), (Ċ, Ȧ, Ḃ) -> Ċ .+= Ȧ*B .+ A*Ḃ | ||
MikeInnes marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @tangent A::AbstractArray * B::AbstractArray = A*B, (Ȧ, Ḃ) -> Ȧ*B .+ A*Ḃ | ||
|
|
||
| @tangent sum(x; dims = :) = sum(x; dims = dims), ẋ -> sum(x, dims = dims) | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| using IRTools.All | ||
| using IRTools: Pipe | ||
| using Base: tail | ||
|
|
||
| ntail(x, n) = n <= 0 ? x : xcall(:tail, ntail(x, n-1)) | ||
|
|
||
| function instrument!(pr, v, st) | ||
| ex = st.expr | ||
| if isexpr(ex, :new) | ||
| st = stmt(st, expr = xcall(Zygote, :__new__, ex.args...)) | ||
| pr[v] = st | ||
| elseif isexpr(ex, :splatnew) | ||
| st = stmt(st, expr = xcall(Zygote, :__splatnew__, ex.args...)) | ||
| pr[v] = st | ||
| end | ||
| return st | ||
| end | ||
|
|
||
| function dual(ir) | ||
| args = copy(arguments(ir)) | ||
| dx = argument!(ir, at = 1) | ||
| Δs = Dict() | ||
| for bl in blocks(ir)[2:end], arg in copy(arguments(bl)) | ||
| Δs[arg] = argument!(bl, insert = false) | ||
| end | ||
| pr = Pipe(ir) | ||
| partial(x::Variable) = Δs[x] | ||
| partial(x) = push!(pr, xcall(Forward, :zerolike, x)) | ||
| partial(v, x::Variable) = Δs[x] | ||
| partial(v, x) = insert!(pr, v, xcall(Forward, :zerolike, x)) | ||
| for (i, x) in enumerate(args) | ||
| if i == length(args) && ir.meta.method.isva | ||
| Δs[x] = push!(pr, ntail(dx, i-1)) | ||
| else | ||
| Δs[x] = push!(pr, xcall(:getindex, dx, i)) | ||
| end | ||
| end | ||
| branches(pr) do br | ||
| args = arguments(br) | ||
| if isreturn(br) | ||
| args[1] = push!(pr, xcall(:tuple, args[1], partial(args[1]))) | ||
| else | ||
| for arg in copy(args) | ||
| push!(args, partial(arg)) | ||
| end | ||
| end | ||
| br | ||
| end | ||
| for (v, st) in pr | ||
| st = instrument!(pr, v, st) | ||
| if isexpr(st.expr, :meta, :inbounds, :loopinfo) | ||
| Δs[v] = nothing | ||
| elseif isexpr(st.expr, :boundscheck) || | ||
| (isexpr(st.expr, :call) && st.expr.args[1] == GlobalRef(Base, :not_int)) | ||
| Δs[v] = false | ||
| elseif isexpr(st.expr, :call) | ||
| dargs = insert!(pr, v, xcall(:tuple, partial.((v,), st.expr.args)...)) | ||
| result = insert!(pr, v, stmt(st, expr = xcall(Forward, :_pushforward, dargs, st.expr.args...))) | ||
| pr[v] = xcall(:getindex, result, 1) | ||
| Δs[v] = push!(pr, xcall(:getindex, result, 2)) | ||
| elseif !isexpr(st.expr) | ||
| Δs[v] = push!(pr, xcall(Forward, :zerolike, v)) | ||
| else | ||
| error("Unsupported $(st.expr.head) expression") | ||
| end | ||
| end | ||
| ir = finish(pr) | ||
| return ir | ||
| end | ||
|
|
||
| @dynamo function _pushforward(_, x...) | ||
| ir = IR(x...) | ||
| ir == nothing && return :(error("non-differentiable function $(args[2])")) | ||
| return dual(ir) | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| using MacroTools: @capture, @q, shortdef | ||
| using ZygoteRules: named, typeless, isvararg | ||
| using Base: tail | ||
|
|
||
| drop(x, n) = n == 0 ? x : :(tail($(drop(x, n-1)))) | ||
| drop(n) = x -> drop(x, n) | ||
|
|
||
| # TODO: move to ZygoteRules | ||
| function tangent end | ||
|
|
||
| function gradm(ex) | ||
| @capture(shortdef(ex), (name_(args__) = body_) | | ||
| (name_(args__) where {Ts__} = body_)) || error("Need a function definition") | ||
| kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing | ||
| isclosure = isexpr(name, :(::)) && length(name.args) > 1 | ||
| f, T = isexpr(name, :(::)) ? | ||
| (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) : | ||
| (esc(gensym()), :(Core.Typeof($(esc(name))))) | ||
| kT = :(Core.kwftype($T)) | ||
| Ts == nothing && (Ts = []) | ||
| args = named.(args) | ||
| argnames = Any[typeless(arg) for arg in args] | ||
| !isempty(args) && isvararg(args[end]) && (argnames[end] = :($(argnames[end])...,)) | ||
| args = esc.(args) | ||
| argnames = esc.(argnames) | ||
| Ts = esc.(Ts) | ||
| fargs = kw == nothing ? [:($f::$T), args...] : [kw, :($f::$T), args...] | ||
| dropg = isclosure ? identity : drop(1) | ||
| dropkw = isclosure ? drop(2) : drop(3) | ||
| adj = @q @inline Zygote.Forward.tangent($(fargs...)) where $(Ts...) = $(esc(body)) | ||
| quote | ||
| $adj | ||
| @inline function Zygote.Forward._pushforward(partials, $f::$T, $(args...)) where $(Ts...) | ||
| y, forw = tangent($f, $(argnames...)) | ||
| return y, forw($(dropg(:partials))...) | ||
| end | ||
| @inline function Zygote.Forward._pushforward(dargs, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) | ||
| y, forw = tangent($f, $(argnames...)) | ||
| return y, forw($(dropkw(:partials))...) | ||
| end | ||
| nothing | ||
| end | ||
| end | ||
|
|
||
| macro tangent(ex) | ||
| gradm(ex) | ||
| end | ||
|
|
||
| pushforward(f, x...) = (ẋ...) -> _pushforward((zerolike(f), ẋ...), f, x...)[2] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| zerolike(x::Number) = zero(x) | ||
| zerolike(x::Tuple) = zerolike.(x) | ||
| zerolike(x::T) where T = | ||
| NamedTuple{fieldnames(T)}(map(f -> zerolike(getfield(x, f)), fieldnames(T))) | ||
| # TODO figure out why this made a test fail | ||
| zerolike(x::Union{Module,Type}) = false | ||
|
|
||
| # TODO: `@nograd` and `@linear` | ||
|
|
||
| @tangent zerolike(x) = zerolike(x), _ -> zerolike(x) | ||
| @tangent one(x) = one(x), _ -> zerolike(x) | ||
| @tangent typeof(x) = typeof(x), _ -> nothing | ||
| @tangent Core.apply_type(args...) = Core.apply_type(args...), (_...) -> nothing | ||
| @tangent fieldtype(args...) = fieldtype(args...), (_...) -> nothing | ||
| @tangent isa(a, b) = isa(a, b), (_, _) -> false | ||
| @tangent repr(x) = repr(x), _ -> nothing | ||
| @tangent println(x...) = println(x...), (_...) -> nothing | ||
| @tangent typeassert(x, T) = typeassert(x, T), (ẋ, _) -> ẋ | ||
| @tangent fieldnames(T) = fieldnames(T), _ -> zerolike(fieldnames(T)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are these 0s and not |
||
| @tangent fieldcount(T) = fieldcount(T), _ -> zerolike(fieldcount(T)) | ||
|
|
||
| @tangent tuple(t...) = t, (ṫ...) -> ṫ | ||
| @tangent tail(t) = tail(t), ṫ -> tail(ṫ) | ||
MikeInnes marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @tangent getfield(t, i) = getfield(t, i), (ṫ, _) -> getfield(ṫ, i) | ||
| @tangent setfield!(t, i, x) = setfield!(t, i, x), (ṫ, _, ẋ) -> setfield!(ṫ, i, ẋ) | ||
| @tangent getindex(t, i) = getindex(t, i), (ṫ, _) -> getindex(ṫ, i) | ||
| @tangent isdefined(t, i) = isdefined(t, i), (_, _) -> false | ||
|
|
||
| # TODO should be using a context for this | ||
| zerolike(x::Core.Box) = isdefined(x, :contents) ? Core.Box(zerolike(x.contents)) : Core.Box() | ||
| @tangent Core.Box() = Core.Box(), () -> Core.Box() | ||
| @tangent Core.Box(x) = Core.Box(x), ẋ -> Core.Box(x) | ||
|
|
||
| @tangent __new__(T, s...) = | ||
| __new__(T, s...), (_, ṡ...) -> NamedTuple{fieldnames(T)}(ṡ) | ||
|
|
||
| @tangent __splatnew__(T, s) = | ||
| __splatnew__(T, s), (_, ṡ) -> NamedTuple{fieldnames(T)}(ṡ) | ||
|
|
||
| function _pushforward(dargs, ::typeof(Core._apply), f, args...) | ||
| dargs = tail(dargs) # drop self gradient | ||
| df, dargs = first(dargs), tail(dargs) | ||
| dargs = Core._apply(tuple, dargs...) | ||
| Core._apply(_pushforward, ((df, dargs...), f), args...) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| using DiffRules, SpecialFunctions, NaNMath | ||
| using Base.FastMath: fast_op, make_fastmath | ||
|
|
||
| # TODO use CSE here | ||
|
|
||
| for (M, f, arity) in DiffRules.diffrules() | ||
| arity == 1 || continue | ||
| dx = DiffRules.diffrule(M, f, :x) | ||
| @eval begin | ||
| @tangent $M.$f(x::Number) = $M.$f(x), ẋ -> ẋ * $dx | ||
| end | ||
| end | ||
|
|
||
| for (M, f, arity) in DiffRules.diffrules() | ||
| arity == 2 || continue | ||
| da, db = DiffRules.diffrule(M, f, :a, :b) | ||
| @eval begin | ||
| @tangent $M.$f(a::Number, b::Number) = $M.$f(a, b), (ȧ, ḃ) -> ȧ*$da + ḃ*$db | ||
| end | ||
| end | ||
|
|
||
| for f in [>, <, ==, ===, !=, in] | ||
| @eval @tangent $f(a, b) = $f(a, b), (_, _) -> false | ||
| end | ||
|
|
||
| @tangent convert(T::Type{<:Real}, x::Real) = convert(T, x), (_, ẋ) -> convert(T, ẋ) | ||
|
|
||
| @tangent function Colon()(xs...) | ||
| c = Colon()(xs...) | ||
| c, (_...) -> zerolike(c) | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| @generated function __new__(T, args...) | ||
| quote | ||
| Base.@_inline_meta | ||
| $(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...)) | ||
| end | ||
| end | ||
|
|
||
| @generated function __splatnew__(T, args) | ||
| quote | ||
| Base.@_inline_meta | ||
| $(Expr(:splatnew, :T, :args)) | ||
| end | ||
| end | ||
|
|
||
| literal_getindex(x, ::Val{i}) where i = getindex(x, i) | ||
| literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i) | ||
| literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state) |
Uh oh!
There was an error while loading. Please reload this page.