Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.23"
version = "1.24"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
169 changes: 140 additions & 29 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,37 @@

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array}
return T(x), T(ẋ)
end

function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T}
return AbstractArray{T}(x), AbstractArray{T}(ẋ)
end

function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end

# This abstract one is used for `float(x)` and other float conversion purposes:
function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T}
project_x = ProjectTo(x)
AbstractArray_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return AbstractArray{T}(x), AbstractArray_pullback
end

#####
##### `vect`
#####

@non_differentiable Base.vect()

function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...)
end

# Case of uniform type `T`: the data passes straight through,
# so no projection should be required.
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
Expand Down Expand Up @@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
return Base.vect(X...), vect_pullback
end

"""
_instantiate_zeros(ẋs, xs)

Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s.
To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this
materialises each zero `ẋ` to be `zero(x)`.
"""
_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs)
_i_zero(ẋ, x) = ẋ
_i_zero(ẋ::AbstractZero, x) = zero(x)
# Possibly this won't work for partly non-diff arrays, sometihng like `gradient(x -> ["abc", x][end], 1)`
# may give a MethodError for `zero` but won't be wrong.

# Fast paths. Should it also collapse all-Zero cases?
_instantiate_zeros(ẋs::Tuple{Vararg{<:Number}}, xs) = ẋs
_instantiate_zeros(ẋs::Tuple{Vararg{<:AbstractArray}}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs

#####
##### `copyto!`
#####

function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
return copyto!(y, x), copyto!(ẏ, ẋ)
end

function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)
return copyto!(y, i, x, js...), copyto!(ẏ, i, ẋ, js...)
end

#####
##### `reshape`
#####

function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}})
A_dims = size(A)
function reshape_pullback(Ȳ)
return (NoTangent(), reshape(Ȳ, A_dims), NoTangent())
end
return reshape(A, dims), reshape_pullback
function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...)
return reshape(x, dims...), reshape(ẋ, dims...)
end

function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
A_dims = size(A)
function reshape_pullback(Ȳ)
∂A = reshape(Ȳ, A_dims)
∂dims = broadcast(Returns(NoTangent()), dims)
return (NoTangent(), ∂A, ∂dims...)
end
function rrule(::typeof(reshape), A::AbstractArray, dims...)
ax = axes(A)
project = ProjectTo(A) # Projection is here for e.g. reshape(::Diagonal, :)
∂dims = broadcast(Returns(NoTangent()), dims)
reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I correct in saying project will not do the reshaping for us, as it only handles cases with singleton dimensions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It will accept arrays whose size is almost right, i.e. differing by trailing 1s only. The offsets can be wrong. Doing one reshape(... , axes) here should mean it never reshapes twice.

return reshape(A, dims...), reshape_pullback
end

#####
##### `dropdims`
#####

function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims)
return dropdims(x; dims), dropdims(ẋ; dims)
end

function rrule(::typeof(dropdims), A::AbstractArray; dims)
ax = axes(A)
project = ProjectTo(A)
dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)))
return dropdims(A; dims), dropdims_pullback
end

#####
##### `permutedims`
#####

function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...)
return permutedims(x, perm...), permutedims(ẋ, perm...)
end

function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...)
return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...)
end

function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
return PermutedDimsArray(x, perm), PermutedDimsArray(ẋ, perm)
end

function rrule(::typeof(permutedims), x::AbstractVector)
project = ProjectTo(x)
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
Expand All @@ -91,6 +162,10 @@ end
##### `repeat`
#####

function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
end

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))

project_Xs = ProjectTo(xs)
Expand Down Expand Up @@ -130,6 +205,10 @@ end
##### `hcat`
#####

function frule((_, ẋs...), ::typeof(hcat), xs...)
return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
Expand Down Expand Up @@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
return Y, hcat_pullback
end

function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
return reduce(hcat, As), reduce(hcat, _instantiate_zeros(Ȧs, As))
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
widths = map(A -> size(A,2), As)
function reduce_hcat_pullback_2(dY)
Expand Down Expand Up @@ -192,6 +275,10 @@ end
##### `vcat`
#####

function frule((_, ẋs...), ::typeof(vcat), xs...)
return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
Y = vcat(Xs...)
ndimsY = Val(ndims(Y))
Expand Down Expand Up @@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
return Y, vcat_pullback
end

function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
return reduce(vcat, As), reduce(vcat, _instantiate_zeros(Ȧs, As))
end

function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
Y = reduce(vcat, As)
ndimsY = Val(ndims(Y))
Expand All @@ -247,6 +338,10 @@ end

_val(::Val{x}) where {x} = x

function frule((_, ẋs...), ::typeof(cat), xs...; dims)
return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims)
end

function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
Y = cat(Xs...; dims=dims)
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
Expand Down Expand Up @@ -285,6 +380,10 @@ end
##### `hvcat`
#####

function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
Y = hvcat(rows, values...)
cols = size(Y,2)
Expand Down Expand Up @@ -321,8 +420,12 @@ end
# 1-dim case allows start/stop, N-dim case takes dims keyword
# whose defaults changed in Julia 1.6... just pass them all through:

function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse(x, args...; kw...), reverse(ẋ, args...; kw...)
end

function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse!(x, args...; kw...), reverse!(ẋ, args...; kw...)
end

function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
Expand All @@ -338,8 +441,12 @@ end
##### `circshift`
#####

function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts)
return circshift(x, shifts), circshift(xdot, shifts)
function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts)
return circshift(x, shifts), circshift(ẋ, shifts)
end

function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts)
return circshift!(y, x, shifts), circshift!(ẏ, ẋ, shifts)
end

function rrule(::typeof(circshift), x::AbstractArray, shifts)
Expand All @@ -355,8 +462,12 @@ end
##### `fill`
#####

function frule((_, xdot), ::typeof(fill), x::Any, dims...)
return fill(x, dims...), fill(xdot, dims...)
function frule((_, ẋ), ::typeof(fill), x::Any, dims...)
return fill(x, dims...), fill(ẋ, dims...)
end

function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any)
return fill!(y, x), fill!(ẏ, ẋ)
end

function rrule(::typeof(fill), x::Any, dims...)
Expand All @@ -370,9 +481,9 @@ end
##### `filter`
#####

function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray)
function frule((_, _, ), ::typeof(filter), f, x::AbstractArray)
inds = findall(f, x)
return x[inds], xdot[inds]
return x[inds], [inds]
end

function rrule(::typeof(filter), f, x::AbstractArray)
Expand All @@ -392,9 +503,9 @@ end
for findm in (:findmin, :findmax)
findm_pullback = Symbol(findm, :_pullback)

@eval function frule((_, xdot), ::typeof($findm), x; dims=:)
@eval function frule((_, ), ::typeof($findm), x; dims=:)
y, ind = $findm(x; dims=dims)
return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent())
return (y, ind), Tangent{typeof((y, ind))}([ind], NoTangent())
end

@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
Expand Down Expand Up @@ -441,8 +552,8 @@ end
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...)
function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
end

function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
Expand All @@ -457,9 +568,9 @@ end

# These rules for `maximum` pick the same subgradient as `findmax`:

function frule((_, xdot), ::typeof(maximum), x; dims=:)
function frule((_, ), ::typeof(maximum), x; dims=:)
y, ind = findmax(x; dims=dims)
return y, xdot[ind]
return y, [ind]
end

function rrule(::typeof(maximum), x::AbstractArray; dims=:)
Expand All @@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:)
return y, maximum_pullback
end

function frule((_, xdot), ::typeof(minimum), x; dims=:)
function frule((_, ), ::typeof(minimum), x; dims=:)
y, ind = findmin(x; dims=dims)
return y, xdot[ind]
return y, [ind]
end

function rrule(::typeof(minimum), x::AbstractArray; dims=:)
Expand Down
24 changes: 23 additions & 1 deletion src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ end
##### `*`
#####

frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB)

frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC


function rrule(
::typeof(*),
Expand Down Expand Up @@ -88,7 +92,9 @@ function rrule(
end



#####
##### `*` matrix-scalar_rule
#####

function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
Expand Down Expand Up @@ -204,6 +210,11 @@ end # VERSION
##### `muladd`
#####

function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
Ω = muladd(A, B, z)
return Ω, ΔA * B .+ A * ΔB .+ Δz
end

function rrule(
::typeof(muladd),
A::AbstractMatrix{<:CommutativeMulNumber},
Expand Down Expand Up @@ -351,6 +362,13 @@ end
##### `\`, `/` matrix-scalar_rule
#####

function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
return A/b, ΔA/b - A*(Δb/b^2)
end
function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber})
return B/a, ΔB/a - B*(Δa/a^2)
end

function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
Y = A/b
function slash_pullback_scalar(ȳ)
Expand Down Expand Up @@ -378,6 +396,8 @@ end
##### Negation (Unary -)
#####

frule((_, ΔA), ::typeof(-), A::AbstractArray) = -A, -ΔA

function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
Expand All @@ -390,6 +410,8 @@ end
##### Addition (Multiarg `+`)
#####

frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)

function rrule(::typeof(+), arrs::AbstractArray...)
y = +(arrs...)
arr_axs = map(axes, arrs)
Expand Down
Loading