Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ version = "0.2.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3"
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.7"
version = "0.17.9"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -74,9 +74,11 @@ version = "0.10.8"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306"
git-tree-sha1 = "95cb11304fe8301fbbc8b95c672b5cfc6d5ad6ed"
repo-rev = "master"
repo-url = "https://github.com/MikeInnes/IRTools.jl.git"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.3.0"
version = "0.3.1"

[[IntelOpenMP_jll]]
deps = ["Libdl", "Pkg"]
Expand All @@ -103,10 +105,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MKL_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "61069ae718b8ab1e325bbfb4e5268902e7ea08e3"
deps = ["IntelOpenMP_jll", "Libdl", "Pkg"]
git-tree-sha1 = "720629cc8cbd12c146ca01b661fd1a6cf66e2ff4"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2019.0.117+0"
version = "2019.0.117+2"

[[MacroTools]]
deps = ["DataStructures", "Markdown", "Random"]
Expand Down
6 changes: 5 additions & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ using IRTools
using MacroTools, Requires
using MacroTools: @forward

export Params, gradient, pullback, @code_grad
export Params, gradient, pullback, pushforward, @code_grad

include("tools/idset.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")

include("forward/Forward.jl")
using .Forward

include("compiler/reverse.jl")
include("compiler/emit.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/forward/Forward.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module Forward

import ..Zygote
import ..Zygote: __new__, __splatnew__

export pushforward

include("compiler.jl")
include("interface.jl")
include("lib.jl")
include("number.jl")
include("array.jl")

end
25 changes: 25 additions & 0 deletions src/forward/array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using LinearAlgebra

zerolike(x::AbstractArray) = zerolike.(x)

@tangent function (A::Type{<:Array})(::UndefInitializer, sz::Integer...)
x = A(UndefInitializer(), sz...)
x, (_...) -> zerolike(x)
end

@tangent length(A::AbstractArray) = length(A), _ -> 0
@tangent size(A::AbstractArray, i::Integer) = size(A, i), (_, _) -> 0
@tangent size(A::AbstractArray) = size(A), _ -> zerolike(size(A))

@tangent Base.vect(xs...) = Base.vect(xs...), Base.vect

@tangent first(x) = first(x), ẋ -> first(ẋ)

@tangent setindex!(x::AbstractArray, v, inds...) =
setindex!(x, v, inds...), (ẋ, v̇, _...) -> setindex!(ẋ, v̇, inds...)

@tangent mul!(C, A, B) = mul!(C, A, B), (Ċ, Ȧ, Ḃ) -> Ċ .+= Ȧ*B .+ A*Ḃ

@tangent A::AbstractArray * B::AbstractArray = A*B, (Ȧ, Ḃ) -> Ȧ*B .+ A*Ḃ

@tangent sum(x; dims = :) = sum(x; dims = dims), ẋ -> sum(x, dims = dims)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this rule's wrong

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops – do you want commit so you can hack on this branch?

75 changes: 75 additions & 0 deletions src/forward/compiler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using IRTools.All
using IRTools: Pipe
using Base: tail

ntail(x, n) = n <= 0 ? x : xcall(:tail, ntail(x, n-1))

function instrument!(pr, v, st)
ex = st.expr
if isexpr(ex, :new)
st = stmt(st, expr = xcall(Zygote, :__new__, ex.args...))
pr[v] = st
elseif isexpr(ex, :splatnew)
st = stmt(st, expr = xcall(Zygote, :__splatnew__, ex.args...))
pr[v] = st
end
return st
end

function dual(ir)
args = copy(arguments(ir))
dx = argument!(ir, at = 1)
Δs = Dict()
for bl in blocks(ir)[2:end], arg in copy(arguments(bl))
Δs[arg] = argument!(bl, insert = false)
end
pr = Pipe(ir)
partial(x::Variable) = Δs[x]
partial(x) = push!(pr, xcall(Forward, :zerolike, x))
partial(v, x::Variable) = Δs[x]
partial(v, x) = insert!(pr, v, xcall(Forward, :zerolike, x))
for (i, x) in enumerate(args)
if i == length(args) && ir.meta.method.isva
Δs[x] = push!(pr, ntail(dx, i-1))
else
Δs[x] = push!(pr, xcall(:getindex, dx, i))
end
end
branches(pr) do br
args = arguments(br)
if isreturn(br)
args[1] = push!(pr, xcall(:tuple, args[1], partial(args[1])))
else
for arg in copy(args)
push!(args, partial(arg))
end
end
br
end
for (v, st) in pr
st = instrument!(pr, v, st)
if isexpr(st.expr, :meta, :inbounds, :loopinfo)
Δs[v] = nothing
elseif isexpr(st.expr, :boundscheck) ||
(isexpr(st.expr, :call) && st.expr.args[1] == GlobalRef(Base, :not_int))
Δs[v] = false
elseif isexpr(st.expr, :call)
dargs = insert!(pr, v, xcall(:tuple, partial.((v,), st.expr.args)...))
result = insert!(pr, v, stmt(st, expr = xcall(Forward, :_pushforward, dargs, st.expr.args...)))
pr[v] = xcall(:getindex, result, 1)
Δs[v] = push!(pr, xcall(:getindex, result, 2))
elseif !isexpr(st.expr)
Δs[v] = push!(pr, xcall(Forward, :zerolike, v))
else
error("Unsupported $(st.expr.head) expression")
end
end
ir = finish(pr)
return ir
end

@dynamo function _pushforward(_, x...)
ir = IR(x...)
ir == nothing && return :(error("non-differentiable function $(args[2])"))
return dual(ir)
end
49 changes: 49 additions & 0 deletions src/forward/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using MacroTools: @capture, @q, shortdef
using ZygoteRules: named, typeless, isvararg
using Base: tail

drop(x, n) = n == 0 ? x : :(tail($(drop(x, n-1))))
drop(n) = x -> drop(x, n)

# TODO: move to ZygoteRules
function tangent end

function gradm(ex)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {Ts__} = body_)) || error("Need a function definition")
kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing
isclosure = isexpr(name, :(::)) && length(name.args) > 1
f, T = isexpr(name, :(::)) ?
(length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) :
(esc(gensym()), :(Core.Typeof($(esc(name)))))
kT = :(Core.kwftype($T))
Ts == nothing && (Ts = [])
args = named.(args)
argnames = Any[typeless(arg) for arg in args]
!isempty(args) && isvararg(args[end]) && (argnames[end] = :($(argnames[end])...,))
args = esc.(args)
argnames = esc.(argnames)
Ts = esc.(Ts)
fargs = kw == nothing ? [:($f::$T), args...] : [kw, :($f::$T), args...]
dropg = isclosure ? identity : drop(1)
dropkw = isclosure ? drop(2) : drop(3)
adj = @q @inline Zygote.Forward.tangent($(fargs...)) where $(Ts...) = $(esc(body))
quote
$adj
@inline function Zygote.Forward._pushforward(partials, $f::$T, $(args...)) where $(Ts...)
y, forw = tangent($f, $(argnames...))
return y, forw($(dropg(:partials))...)
end
@inline function Zygote.Forward._pushforward(dargs, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
y, forw = tangent($f, $(argnames...))
return y, forw($(dropkw(:partials))...)
end
nothing
end
end

macro tangent(ex)
gradm(ex)
end

pushforward(f, x...) = (ẋ...) -> _pushforward((zerolike(f), ẋ...), f, x...)[2]
46 changes: 46 additions & 0 deletions src/forward/lib.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
zerolike(x::Number) = zero(x)
zerolike(x::Tuple) = zerolike.(x)
zerolike(x::T) where T =
NamedTuple{fieldnames(T)}(map(f -> zerolike(getfield(x, f)), fieldnames(T)))
# TODO figure out why this made a test fail
zerolike(x::Union{Module,Type}) = false

# TODO: `@nograd` and `@linear`

@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
@tangent one(x) = one(x), _ -> zerolike(x)
@tangent typeof(x) = typeof(x), _ -> nothing
@tangent Core.apply_type(args...) = Core.apply_type(args...), (_...) -> nothing
@tangent fieldtype(args...) = fieldtype(args...), (_...) -> nothing
@tangent isa(a, b) = isa(a, b), (_, _) -> false
@tangent repr(x) = repr(x), _ -> nothing
@tangent println(x...) = println(x...), (_...) -> nothing
@tangent typeassert(x, T) = typeassert(x, T), (ẋ, _) -> ẋ
@tangent fieldnames(T) = fieldnames(T), _ -> zerolike(fieldnames(T))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these 0s and not nothing?

@tangent fieldcount(T) = fieldcount(T), _ -> zerolike(fieldcount(T))

@tangent tuple(t...) = t, (ṫ...) -> ṫ
@tangent tail(t) = tail(t), ṫ -> tail(ṫ)

@tangent getfield(t, i) = getfield(t, i), (ṫ, _) -> getfield(ṫ, i)
@tangent setfield!(t, i, x) = setfield!(t, i, x), (ṫ, _, ẋ) -> setfield!(ṫ, i, ẋ)
@tangent getindex(t, i) = getindex(t, i), (ṫ, _) -> getindex(ṫ, i)
@tangent isdefined(t, i) = isdefined(t, i), (_, _) -> false

# TODO should be using a context for this
zerolike(x::Core.Box) = isdefined(x, :contents) ? Core.Box(zerolike(x.contents)) : Core.Box()
@tangent Core.Box() = Core.Box(), () -> Core.Box()
@tangent Core.Box(x) = Core.Box(x), ẋ -> Core.Box(x)

@tangent __new__(T, s...) =
__new__(T, s...), (_, ṡ...) -> NamedTuple{fieldnames(T)}(ṡ)

@tangent __splatnew__(T, s) =
__splatnew__(T, s), (_, ṡ) -> NamedTuple{fieldnames(T)}(ṡ)

function _pushforward(dargs, ::typeof(Core._apply), f, args...)
dargs = tail(dargs) # drop self gradient
df, dargs = first(dargs), tail(dargs)
dargs = Core._apply(tuple, dargs...)
Core._apply(_pushforward, ((df, dargs...), f), args...)
end
31 changes: 31 additions & 0 deletions src/forward/number.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using DiffRules, SpecialFunctions, NaNMath
using Base.FastMath: fast_op, make_fastmath

# TODO use CSE here

for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
dx = DiffRules.diffrule(M, f, :x)
@eval begin
@tangent $M.$f(x::Number) = $M.$f(x), ẋ -> ẋ * $dx
end
end

for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :a, :b)
@eval begin
@tangent $M.$f(a::Number, b::Number) = $M.$f(a, b), (ȧ, ḃ) -> ȧ*$da + ḃ*$db
end
end

for f in [>, <, ==, ===, !=, in]
@eval @tangent $f(a, b) = $f(a, b), (_, _) -> false
end

@tangent convert(T::Type{<:Real}, x::Real) = convert(T, x), (_, ẋ) -> convert(T, ẋ)

@tangent function Colon()(xs...)
c = Colon()(xs...)
c, (_...) -> zerolike(c)
end
18 changes: 0 additions & 18 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ using Base: tail

@adjoint tuple(xs...) = xs, identity

literal_getindex(x, ::Val{i}) where i = getindex(x, i)
literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i)
literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state)

@adjoint literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} =
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))

Expand Down Expand Up @@ -233,20 +229,6 @@ end
end
end

@generated function __new__(T, args...)
quote
Base.@_inline_meta
$(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...))
end
end

@generated function __splatnew__(T, args)
quote
Base.@_inline_meta
$(Expr(:splatnew, :T, :args))
end
end

struct Jnew{T,G,splat}
g::G
end
Expand Down
17 changes: 17 additions & 0 deletions src/tools/builtins.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@generated function __new__(T, args...)
quote
Base.@_inline_meta
$(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...))
end
end

@generated function __splatnew__(T, args)
quote
Base.@_inline_meta
$(Expr(:splatnew, :T, :args))
end
end

literal_getindex(x, ::Val{i}) where i = getindex(x, i)
literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i)
literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state)
Loading