Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- "1.0" # LTS
- "1.6" # LTS
- "1" # Latest Release
os:
- ubuntu-latest
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.21"
version = "1.22"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -15,10 +15,10 @@ ChainRulesCore = "1.11.5"
ChainRulesTestUtils = "1"
Compat = "3.35"
FiniteDifferences = "0.12.20"
JuliaInterpreter = "0.8"
JuliaInterpreter = "0.8" # latest is "0.9.1"
RealDot = "0.1"
StaticArrays = "1.2"
julia = "1"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Expand Down
7 changes: 0 additions & 7 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@ using Statistics
# to the normal rule of only overload via `ChainRulesCore.rrule`.
import ChainRulesCore: rrule, frule

if VERSION < v"1.3.0-DEV.142"
# In prior versions, the BLAS submodule also exported `dot`, which caused a conflict
# with its parent module. To get around this, we can simply create a hard binding for
# the one we want to use without qualification.
import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

Expand Down
4 changes: 1 addition & 3 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ end
@scalar_rule sinc(x) cosc(x)

# the position of the minus sign below warrants the correct type for π
if VERSION ≥ v"1.6"
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))
end
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))

@scalar_rule(
clamp(x, low, high),
Expand Down
258 changes: 128 additions & 130 deletions src/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
@@ -1,151 +1,149 @@

if VERSION ≥ v"1.4"
function frule((_, ẋ, ṗ), ::typeof(evalpoly), x, p::Union{Tuple,AbstractVector})
Δx, Δp = ẋ, unthunk(ṗ)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
function frule((_, ẋ, ṗ), ::typeof(evalpoly), x, p::Union{Tuple,AbstractVector})
Δx, Δp = ẋ, unthunk(ṗ)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
end

function rrule(::typeof(evalpoly), x, p::Union{Tuple,AbstractVector})
y, ys = _evalpoly_intermediates(x, p)
project_x = ProjectTo(x)
project_p = p isa Tuple ? identity : ProjectTo(p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NoTangent(), project_x(∂x), project_p(∂p)
end
return y, evalpoly_pullback
function rrule(::typeof(evalpoly), x, p::Union{Tuple,AbstractVector})
y, ys = _evalpoly_intermediates(x, p)
project_x = ProjectTo(x)
project_p = p isa Tuple ? identity : ProjectTo(p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NoTangent(), project_x(∂x), project_p(∂p)
end
return y, evalpoly_pullback
end

function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NoTangent(), ∂x, ∂p
end
return y, evalpoly_pullback
function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NoTangent(), ∂x, ∂p
end
return y, evalpoly_pullback
end

# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end

# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x
# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x

@inline _evalpoly_backp(pi, ∂yi) = ∂yi
@inline _evalpoly_backp(pi, ∂yi) = ∂yi

function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Tangent{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Tangent{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = unthunk(Δy)
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = unthunk(Δy)
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Tangent{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Tangent{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
end
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
end
31 changes: 8 additions & 23 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,10 @@ function frule(
y = sum(abs2, x; dims=dims)
∂y = if dims isa Colon
2 * realdot(x, ẋ)
elseif VERSION ≥ v"1.2" # multi-iterator mapreduce introduced in v1.2
else
mapreduce(+, x, ẋ; dims=dims) do xi, dxi
2 * realdot(xi, dxi)
end
else
2 * sum(realdot.(x, ẋ); dims=dims)
end
return y, ∂y
end
Expand Down Expand Up @@ -419,16 +417,10 @@ end
# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
# be replaced by _peel1 like Iterators.peel

if VERSION >= v"1.6"
_reverse1(x) = Iterators.reverse(x)
_drop1(x) = Iterators.drop(x, 1)
_zip2(x, y) = zip(x, y) # for `accumulate`, below
else
# Old versions don't support accumulate(::itr), nor multi-dim reverse
_reverse1(x) = reverse(vec(x))
_drop1(x) = vec(x)[2:end]
_zip2(x, y) = collect(zip(x, y))
end
_reverse1(x) = Iterators.reverse(x)
_drop1(x) = Iterators.drop(x, 1)
_zip2(x, y) = zip(x, y) # for `accumulate`, below

_reverse1(x::Tuple) = reverse(x)
_drop1(x::Tuple) = Base.tail(x)
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
Expand Down Expand Up @@ -480,16 +472,9 @@ function rrule(
function decumulate(dy)
dy_plain = _no_tuple_tangent(unthunk(dy))
rev_list = if init === _InitialValue()
if VERSION >= v"1.6"
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
else
# However, on 1.0 and some others, zip does not stop early. But since accumulate
# also doesn't work on iterators, `_drop1` doesn't make one, so this should work:
_zip2(_reverse1(hobbits), _reverse1(_drop1(dy_plain)))
# What an awful tangle.
end
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
else
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
end
Expand Down
Loading