From 0932a86d2efc81e272756385543e4bbf31c81eb7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 18 Jan 2022 21:39:38 -0500 Subject: [PATCH 01/18] drop 1.0, now that LTS == 1.6 --- Project.toml | 1 - test/Project.toml | 23 +++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 6c3f90242..41739fe0d 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1.11.5" -ChainRulesTestUtils = "1" Compat = "3.35" FiniteDifferences = "0.12.20" IrrationalConstants = "0.1.1" diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..84e3ebd58 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,23 @@ +[deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +# and test-only: +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ChainRulesCore = "1.11.5" +ChainRulesTestUtils = "1.3.1" +Compat = "3.35" +FiniteDifferences = "0.12.20" +JuliaInterpreter = "0.8" # latest is "0.9.1" +RealDot = "0.1" +StaticArrays = "1.2" +julia = "1.6" From edb1bbd2b432932fed7b146333e91a93038ea259 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Jan 2022 10:14:21 -0500 Subject: [PATCH 02/18] revert to one Project --- Project.toml | 1 + test/Project.toml | 23 ----------------------- 2 files 
changed, 1 insertion(+), 23 deletions(-) delete mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 41739fe0d..6c3f90242 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1.11.5" +ChainRulesTestUtils = "1" Compat = "3.35" FiniteDifferences = "0.12.20" IrrationalConstants = "0.1.1" diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 84e3ebd58..000000000 --- a/test/Project.toml +++ /dev/null @@ -1,23 +0,0 @@ -[deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -# and test-only: -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -ChainRulesCore = "1.11.5" -ChainRulesTestUtils = "1.3.1" -Compat = "3.35" -FiniteDifferences = "0.12.20" -JuliaInterpreter = "0.8" # latest is "0.9.1" -RealDot = "0.1" -StaticArrays = "1.2" -julia = "1.6" From 60a73b53fa9646a0189b7ef1f4eaf0ede1425759 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Jan 2022 10:15:12 -0500 Subject: [PATCH 03/18] rm Compat --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6c3f90242..f0751b795 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1.11.5" ChainRulesTestUtils = "1" -Compat = "3.35" FiniteDifferences = "0.12.20" IrrationalConstants = "0.1.1" JuliaInterpreter = 
"0.8" # latest is "0.9.1" From 94913bce425b937ddefa031d0baf913b2dc04f37 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Jan 2022 21:37:36 -0500 Subject: [PATCH 04/18] turns out this does still need Compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index f0751b795..6c3f90242 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1.11.5" ChainRulesTestUtils = "1" +Compat = "3.35" FiniteDifferences = "0.12.20" IrrationalConstants = "0.1.1" JuliaInterpreter = "0.8" # latest is "0.9.1" From 94d58984ccd7a6e0f35e6c817833a34344955fa0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 14 Jan 2022 01:01:04 -0500 Subject: [PATCH 05/18] add many frules --- src/rulesets/Base/array.jl | 64 +++++++++++++++++++++++++++++++++ src/rulesets/Base/arraymath.jl | 24 ++++++++++++- src/rulesets/Base/indexing.jl | 25 +++++++++++++ src/rulesets/Base/sort.jl | 16 +++++++++ test/rulesets/Base/array.jl | 49 +++++++++++++++++++++++++ test/rulesets/Base/arraymath.jl | 25 ++++++++++++- test/rulesets/Base/indexing.jl | 9 +++++ test/rulesets/Base/sort.jl | 10 ++++++ 8 files changed, 220 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b6ae01ce5..9c580a427 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -4,6 +4,10 @@ ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) +function frule((_, xdot), ::Type{T}, x::AbstractArray) where {T<:Array} + return T(x), T(xdot) +end + function rrule(::Type{T}, x::AbstractArray) where {T<:Array} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) @@ -16,6 +20,10 @@ end @non_differentiable Base.vect() +function frule((_, xdots...), ::typeof(Base.vect), xs::Number...) 
+ return Base.vect(xs...), Base.vect(_make_real_zeros(xdots, xs)...) +end + # Case of uniform type `T`: the data passes straight through, # so no projection should be required. function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} @@ -43,10 +51,34 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N} return Base.vect(X...), vect_pullback end +""" + _make_real_zeros(xdots, xs) + +Forward rules for `vect` or `cat` may receive a mixture of data and `ZeroTangent`s. +To avoid `vect(1, ZeroTangent(), 3)` or `hcat([1,2], ZeroTangent())`, this materialises +each zero `xdot` to be `zero(x)`. +""" +_make_real_zeros(xdots, xs) = map(_real_zero, xdots, xs) +_real_zero(xdot, x) = xdot +_real_zero(xdot::AbstractZero, x) = zero(x) + +# Fast paths. Should it also collapse all-Zero cases? +_make_real_zeros(xdots::NTuple{<:Any, <:Number}, xs) = xdots +_make_real_zeros(xdots::AbstractArray{<:Number}, xs) = xdots +_make_real_zeros(xdots::AbstractArray{<:AbstractArray}, xs) = xdots + +frrule((_, xdd, _), ::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xdots, xs), xdd # not very sure! + +rrule(::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xdots, xs), dydots -> (NoTangent(), dydots, NoTangent()) + ##### ##### `reshape` ##### +function frule((_, xdot), ::typeof(reshape), x::AbstractArray, dims...) + return reshape(x, dims...), reshape(xdot, dims...) +end + function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}}) A_dims = size(A) function reshape_pullback(Ȳ) @@ -69,6 +101,10 @@ end ##### `permutedims` ##### +function frule((_, xdot), ::typeof(permutedims), x::AbstractArray, perm...) + return permutedims(x, perm...), permutedims(xdot, perm...) 
+end + function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) @@ -91,6 +127,10 @@ end ##### `repeat` ##### +function frule((_, xsdot), ::typeof(repeat), xs::AbstractArray, cnt...; kw...) + return repeat(xs, cnt...; kw...), repeat(xsdot, cnt...; kw...) +end + function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs))) project_Xs = ProjectTo(xs) @@ -130,6 +170,10 @@ end ##### `hcat` ##### +function frule((_, xdots...), ::typeof(hcat), xs...) + return hcat(xs...), hcat(_make_real_zeros(xdots, xs)...) +end + function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability @@ -164,6 +208,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) return Y, hcat_pullback end +function frule((_, _, Adots), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(hcat, As), reduce(hcat, _make_real_zeros(Adots, As)) +end + function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) widths = map(A -> size(A,2), As) function reduce_hcat_pullback_2(dY) @@ -192,6 +240,10 @@ end ##### `vcat` ##### +function frule((_, xdots...), ::typeof(vcat), xs...) + return vcat(xs...), vcat(_make_real_zeros(xdots, xs)...) +end + function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) Y = vcat(Xs...) ndimsY = Val(ndims(Y)) @@ -224,6 +276,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) 
return Y, vcat_pullback end +function frule((_, _, Adots), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(vcat, As), reduce(vcat, _make_real_zeros(Adots, As)) +end + function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) Y = reduce(vcat, As) ndimsY = Val(ndims(Y)) @@ -247,6 +303,10 @@ end _val(::Val{x}) where {x} = x +function frule((_, xdots...), ::typeof(cat), xs...; dims) + return cat(xs...; dims), cat(_make_real_zeros(xdots, xs)...; dims) +end + function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) Y = cat(Xs...; dims=dims) cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims) @@ -285,6 +345,10 @@ end ##### `hvcat` ##### +function frule((_, _, xdots...), ::typeof(hvcat), rows, xs...) + return hvcat(rows, xs...), hvcat(rows, _make_real_zeros(xdots, xs)...) +end + function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...) Y = hvcat(rows, values...) 
cols = size(Y,2) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 9b442c502..e20d02acf 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,6 +19,10 @@ end ##### `*` ##### +frule((_, Adot, Bdot), ::typeof(*), A, B) = A * B, muladd(Adot, B, A * Bdot) + +frule((_, Adot, Bdot, Cdot), ::typeof(*), A, B, C) = A*B*C, Adot*B*C + A*Bdot*C + A*B*Cdot + function rrule( ::typeof(*), @@ -88,7 +92,9 @@ function rrule( end - +##### +##### `*` matrix-scalar_rule +##### function rrule( ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} @@ -204,6 +210,11 @@ end # VERSION ##### `muladd` ##### +function frule((_, Adot, Bdot, zdot), ::typeof(muladd), A, B, z) + Ω = muladd(A, B, z) + return Ω, Adot * B .+ A * Bdot .+ zdot +end + function rrule( ::typeof(muladd), A::AbstractMatrix{<:CommutativeMulNumber}, @@ -351,6 +362,13 @@ end ##### `\`, `/` matrix-scalar_rule ##### +function frule((_, Adot, bdot), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) + return A/b, Adot/b - A*(bdot/b^2) +end +function frule((_, adot, Bdot), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}) + return B/a, Bdot/a - B*(adot/a^2) +end + function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) Y = A/b function slash_pullback_scalar(ȳ) @@ -378,6 +396,8 @@ end ##### Negation (Unary -) ##### +frule((_, Adot), ::typeof(-), A::AbstractArray) = -A, -Adot + function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ)) @@ -390,6 +410,8 @@ end ##### Addition (Multiarg `+`) ##### +frule((_, Adots...), ::typeof(+), As::AbstractArray...) = +(As...), +(Adots...) + function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) 
arr_axs = map(axes, arrs) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 65dd55631..2cee21dce 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -2,6 +2,10 @@ ##### getindex ##### +function frule((_, xdot), ::typeof(getindex), x::AbstractArray, inds...) + return x[inds...], xdot[inds...] +end + function rrule(::typeof(getindex), x::Array{<:Number}, inds...) # removes any logical indexing, CartesianIndex etc # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int @@ -27,3 +31,24 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) return y, getindex_pullback end +##### +##### view +##### + +function frule((_, xdot), ::typeof(view), x::AbstractArray, inds...) + return view(x, inds...), view(xdot, inds...) +end + +##### +##### setindex! +##### + +function frule((_, xdot), ::typeof(setindex!), x::AbstractArray, v, inds...) + @info "setindex!" x v inds xdot + v1 = x[inds...] = v + v2 = xdot[inds...] = v + return v1, v2 +end + + + diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index be7840c8c..b6939e33f 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -2,6 +2,11 @@ ##### `sort` ##### +function frule((_, xsdot, _), ::typeof(partialsort), xs::AbstractVector, k; kw...) + inds = partialsortperm(xs, k; kw...) + return xs[inds], xsdot[inds] +end + function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...) inds = partialsortperm(xs, k; kwargs...) ys = xs[inds] @@ -20,6 +25,11 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin return ys, partialsort_pullback end +function frule((_, xsdot), ::typeof(sort), xs::AbstractVector; kw...) + inds = sortperm(xs; kw...) + return xs[inds], xsdot[inds] +end + function rrule(::typeof(sort), xs::AbstractVector; kwargs...) inds = sortperm(xs; kwargs...) 
ys = xs[inds] @@ -42,6 +52,12 @@ end ##### `sortslices` ##### +function frule((_, xdot), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) + p = sortperm(collect(eachslice(x; dims=dims)); kw...) + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) + return x[inds...], xdot[inds...] +end + function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 7a82bd350..26092d7e2 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -11,6 +11,10 @@ @test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent()) end @testset "from existing array" begin + # fwd + test_frule(Array, randn(2, 5)) + test_frule(Array, Diagonal(randn(5))) + # rev test_rrule(Array, randn(2, 5)) test_rrule(Array, Diagonal(randn(5))) test_rrule(Matrix, Diagonal(randn(5))) @@ -27,6 +31,9 @@ end test_rrule(Base.vect, randn(2, 2), randn(3, 3)) end @testset "inhomogeneous type" begin + # fwd + test_frule(Base.vect, 5.0, 3f0) + # rev test_rrule( Base.vect, 5.0, 3f0; atol=1e-6, rtol=1e-6, @@ -34,17 +41,30 @@ end test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) test_rrule(Base.vect, (5.0, 4.0), (y=randn(3),); check_inferred=false) end + @testset "_make_real_zeros" begin + # This is an internal function also used for `cat` etc. + # It has its own rules to allow for 2nd derivatives. + @eval using ChainRules: _make_real_zeros + @test_skip test_frule(_make_real_zeros, Tuple(rand(3)), Tuple(rand(3))) + @test_skip test_rrule(_make_real_zeros, Tuple(rand(3)), Tuple(rand(3))) + # Not sure these are defined right! 
Currently fail due to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/229 + end end @testset "reshape" begin + # fwd + test_frule(reshape, rand(4, 5), 2, :) + # rev test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) test_rrule(reshape, rand(4, 5), 2, :) end @testset "permutedims + PermutedDimsArray" begin + test_frule(permutedims, rand(5)) test_rrule(permutedims, rand(5)) + test_frule(permutedims, rand(3, 4), (2, 1)) test_rrule(permutedims, rand(3, 4), (2, 1)) test_rrule(permutedims, Diagonal(rand(5)), (2, 1)) # Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all @@ -59,7 +79,11 @@ end end @testset "repeat" begin + # forward + test_frule(repeat, rand(4), 2) + test_frule(repeat, rand(2, 3); fkwargs = (inner=(1,2), outer=(1,3))) + # reverse test_rrule(repeat, rand(4, )) test_rrule(repeat, rand(4, 5)) test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),)) @@ -92,6 +116,11 @@ end end @testset "hcat" begin + # forward + test_frule(hcat, randn(3, 2), randn(3)) + test_frule(hcat, randn(), randn(1,3)) + + # reverse test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3)) test_rrule(hcat, rand(), rand(1,2), rand(1,2,1)) test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2)) @@ -102,9 +131,11 @@ end @testset "reduce hcat" begin mats = [randn(3, 2), randn(3, 1), randn(3, 3)] + test_frule(reduce, hcat, mats) test_rrule(reduce, hcat, mats) vecs = [rand(3) for _ in 1:4] + test_frule(reduce, hcat, vecs) test_rrule(reduce, hcat, vecs) mix = AbstractVecOrMat[rand(4,2), rand(4)] # this is weird, but does hit the fast path @@ -121,20 +152,29 @@ end end @testset "vcat" begin + + # forward + test_frule(vcat, randn(), randn(3), rand()) + test_frule(vcat, randn(3, 1), randn(3)) + + # reverse test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4)) test_rrule(vcat, rand(), rand()) test_rrule(vcat, rand(), rand(3), rand(3,1,1)) test_rrule(vcat, rand(3,1,2), rand(4,1,2)) + # mix types test_rrule(vcat, rand(2, 2), rand(2, 2)') end @testset 
"reduce vcat" begin mats = [randn(2, 4), randn(1, 4), randn(3, 4)] + test_frule(reduce, vcat, mats) test_rrule(reduce, vcat, mats) vecs = [rand(2), rand(3), rand(4)] + test_frule(reduce, vcat, vecs) test_rrule(reduce, vcat, vecs) mix = AbstractVecOrMat[rand(4,1), rand(4)] @@ -142,6 +182,11 @@ end end @testset "cat" begin + # forward + test_frule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,)) + test_frule(cat, rand(), rand(2,3); fkwargs=(dims=(1,2),)) + + # reverse test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,)) test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),)) test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],)) @@ -151,6 +196,10 @@ end end @testset "hvcat" begin + # forward + test_frule(hvcat, 2, rand(6)...) + + # reverse test_rrule(hvcat, 2, rand(ComplexF64, 6)...) test_rrule(hvcat, (2, 1), rand(), rand(1,1), rand(2,2)) test_rrule(hvcat, 1, rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3)) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index b25b16acf..324311356 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -11,24 +11,30 @@ ⋆() = only(⋆(())) # scalar @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5)) + test_frule(*, ⋆(), ⋆(dims)) + test_frule(*, ⋆(dims), ⋆()) + test_rrule(*, ⋆(), ⋆(dims)) test_rrule(*, ⋆(dims), ⋆()) end @testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5) @testset "Array" begin + test_frule(*, n ⋆ m, ⋆(m)) test_rrule(*, n ⋆ m, ⋆(m)) end end @testset "AbstractVector-AbstractMatrix n=$n, m=$m" for n in (2, 3), m in (4, 5) @testset "Array" begin + test_frule(*, ⋆(n), 1 ⋆ m) test_rrule(*, ⋆(n), 1 ⋆ m) end end @testset "dense-matrix n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) @testset "Array" begin + test_frule(*, (n⋆m), (m⋆p)) test_rrule(*, (n⋆m), (m⋆p)) end @@ -53,6 +59,11 @@ end @testset "Diagonal" begin + # fwd + test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) + 
test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) + + # rev test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) @@ -71,6 +82,9 @@ @testset "muladd: $T" for T in (Float64, ComplexF64) @testset "add $(typeof(z))" for z in [rand(T), rand(T, 3), rand(T, 3, 3), false] + @testset "forward mode" begin + test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z) + end @testset "matrix * matrix" begin A = rand(T, 3, 3) B = rand(T, 3, 3) @@ -166,6 +180,10 @@ @testset "/ and \\ Scalar-AbstractArray" begin A = round.(10 .* randn(3, 4, 5), digits=1) + # fwd + test_frule(/, A, 7.2) + test_frule(\, 7.2, A) + # rev test_rrule(/, A, 7.2) test_rrule(\, 7.2, A) @@ -177,12 +195,17 @@ @testset "negation" begin A = randn(4, 4) Ā = randn(4, 4) - + # fwd + test_frule(-, A) + # rev test_rrule(-, A) test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā)) end @testset "addition" begin + # fwd + test_frule(+, randn(2), randn(2), randn(2)) + # rev test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 54939858a..094f26157 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,6 +1,15 @@ @testset "getindex" begin @testset "getindex(::Matrix{<:Number}, ...)" begin x = [1.0 2.0 3.0; 10.0 20.0 30.0] + + @testset "forward mode" begin + test_frule(getindex, x, 2) + test_frule(getindex, x, 2, 1) + test_frule(getindex, x, CartesianIndex(2, 3)) + + test_rrule(getindex, x, 2:3) + test_rrule(getindex, x, (:), 2:3) + end @testset "single element" begin test_rrule(getindex, x, 2) diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index f76109586..052045d1e 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -1,11 +1,19 @@ @testset "sort.jl" begin @testset "sort" begin a = rand(10) + # fwd + test_frule(sort, a) + test_frule(sort, a; 
fkwargs=(;rev=true)) + # rev test_rrule(sort, a) test_rrule(sort, a; fkwargs=(;rev=true)) end @testset "partialsort" begin a = rand(10) + # fwd + test_frule(partialsort, a, 3:5) + test_frule(partialsort, a, 4, fkwargs=(;rev=true)) + # rev test_rrule(partialsort, a, 4) test_rrule(partialsort, a, 3:5) test_rrule(partialsort, a, 1:2:6) @@ -14,6 +22,8 @@ end @testset "sortslices" begin + test_frule(sortslices, rand(3,4); fkwargs=(; dims=2)) + test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) From 4971f995f1e2030a7a9830c6b951648f85d40bf3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 18 Jan 2022 09:04:04 -0500 Subject: [PATCH 06/18] in-place frules --- src/rulesets/Base/array.jl | 28 ++++++++++++++++++++++++++++ src/rulesets/Base/indexing.jl | 1 - src/rulesets/Base/mapreduce.jl | 5 +++++ src/rulesets/LinearAlgebra/dense.jl | 18 ++++++++++++++++++ test/rulesets/Base/array.jl | 26 +++++++++++++++++++++++--- test/rulesets/Base/mapreduce.jl | 5 +++++ test/rulesets/LinearAlgebra/dense.jl | 6 ++++++ 7 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 9c580a427..2539b6d55 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -71,6 +71,18 @@ frrule((_, xdd, _), ::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xd rrule(::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xdots, xs), dydots -> (NoTangent(), dydots, NoTangent()) +##### +##### `copyto!` +##### + +function frule((_, ydot, xdot), ::typeof(copyto!), y::AbstractArray, x) + return copyto!(y, x), copyto!(ydot, xdot) +end + +function frule((_, ydot, _, xdot), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) + return copyto!(y, i, x, js...), copyto!(ydot, i, xdot, js...) 
+end + ##### ##### `reshape` ##### @@ -105,6 +117,10 @@ function frule((_, xdot), ::typeof(permutedims), x::AbstractArray, perm...) return permutedims(x, perm...), permutedims(xdot, perm...) end +function frule((_, ydot, xdot), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...) + return permutedims!(y, x, perm...), permutedims!(ydot, xdot, perm...) +end + function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) @@ -389,6 +405,10 @@ function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, arg return reverse(x, args...; kw...), reverse(xdot, args...; kw...) end +function frule((_, xdot), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...) + return reverse!(x, args...; kw...), reverse!(xdot, args...; kw...) +end + function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) nots = map(Returns(NoTangent()), args) function reverse_pullback(dy) @@ -406,6 +426,10 @@ function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts) return circshift(x, shifts), circshift(xdot, shifts) end +function frule((_, ydot, xdot), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts) + return circshift!(y, x, shifts), circshift!(ydot, xdot, shifts) +end + function rrule(::typeof(circshift), x::AbstractArray, shifts) function circshift_pullback(dy) dx = @thunk circshift(unthunk(dy), map(-, shifts)) @@ -423,6 +447,10 @@ function frule((_, xdot), ::typeof(fill), x::Any, dims...) return fill(x, dims...), fill(xdot, dims...) end +function frule((_, ydot, xdot), ::typeof(fill!), y::AbstractArray, x::Any) + return fill!(y, x), fill!(ydot, xdot) +end + function rrule(::typeof(fill), x::Any, dims...) 
project = ProjectTo(x) nots = map(Returns(NoTangent()), dims) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 2cee21dce..8c58e17df 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -44,7 +44,6 @@ end ##### function frule((_, xdot), ::typeof(setindex!), x::AbstractArray, v, inds...) - @info "setindex!" x v inds xdot v1 = x[inds...] = v v2 = xdot[inds...] = v return v1, v2 diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 21bf194be..fe4e68b68 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -5,10 +5,15 @@ function frule((_, ẋ), ::typeof(sum), x::Tuple) return sum(x), sum(ẋ) end + function frule((_, ẋ), ::typeof(sum), x; dims=:) return sum(x; dims=dims), sum(ẋ; dims=dims) end +function frule((_, ẏ, ẋ), ::typeof(sum!), y::AbstractArray, x::AbstractArray) + return sum!(y, x), sum!(ẏ, ẋ) +end + function rrule(::typeof(sum), x::AbstractArray; dims=:) project = ProjectTo(x) y = sum(x; dims=dims) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 3cb362900..690cac1bd 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -59,6 +59,24 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} return z, dot_pullback end +##### +##### `mul!` +##### + +function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B) + mul!(C, A, B) + mul!(ΔC, ΔA, B) + mul!(ΔC, A, ΔB, true, true) + return C, ΔC +end + +# function frule((_, ΔC, ΔA, ΔB, Δα, Δβ), ::typeof(mul!), C::AbstractArray, A, B, α::Number, β::Number) +# mul!(C, A, B, α, β) +# mul!(ΔC, ΔA, B) +# mul!(ΔC, A, ΔB, true, true) +# return C, ΔC +# end + ##### ##### `cross` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 26092d7e2..fe7c75ac9 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -51,6 +51,13 @@ end end end +@testset 
"copyto!" begin + test_frule(copyto!, rand(5), rand(5)) + test_frule(copyto!, rand(10), 3, rand(5)) + test_frule(copyto!, rand(10), 2, rand(5), 2) + test_frule(copyto!, rand(10), 2, rand(5), 2, 4) +end + @testset "reshape" begin # fwd test_frule(reshape, rand(4, 5), 2, :) @@ -61,10 +68,13 @@ end end @testset "permutedims + PermutedDimsArray" begin + # Forward test_frule(permutedims, rand(5)) - test_rrule(permutedims, rand(5)) - test_frule(permutedims, rand(3, 4), (2, 1)) + test_frule(permutedims!, rand(4,3), rand(3, 4), (2, 1)) + + # Reverse + test_rrule(permutedims, rand(5)) test_rrule(permutedims, rand(3, 4), (2, 1)) test_rrule(permutedims, Diagonal(rand(5)), (2, 1)) # Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all @@ -219,11 +229,14 @@ end test_frule(reverse, rand(5)) test_frule(reverse, rand(5), 2, 4) test_frule(reverse, rand(5), fkwargs=(dims=1,)) - test_frule(reverse, rand(3,4), fkwargs=(dims=2,)) test_frule(reverse, rand(3,4)) test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) + test_frule(reverse!, rand(5)) + test_frule(reverse!, rand(5), 2, 4) + test_frule(reverse!, rand(3,4), fkwargs=(dims=2,)) + # Reverse test_rrule(reverse, rand(5)) test_rrule(reverse, rand(5), 2, 4) @@ -247,6 +260,9 @@ end test_frule(circshift, rand(10), (1,)) test_frule(circshift, rand(3,4), (-7,2)) + test_frule(circshift!, rand(10), rand(10), 1) + test_frule(circshift!, rand(3,4), rand(3,4), (-7,2)) + # Reverse test_rrule(circshift, rand(10), 1) test_rrule(circshift, rand(10) .+ im, -2) @@ -255,9 +271,13 @@ end end @testset "fill" begin + # Forward test_frule(fill, 12.3, 4) test_frule(fill, 5.0, (6, 7)) + test_frule(fill!, rand(2, 3), rand()) + + # Reverse test_rrule(fill, 44.4, 4) test_rrule(fill, 55 + 0.5im, 5) test_rrule(fill, 3.3, (3, 3, 3)) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index c9a740206..f07fc9b32 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -43,6 +43,11 @@ 
const CFG = ChainRulesTestUtils.ADviaRuleConfig() end end + @testset "sum!(y, x)" begin + test_frule(sum!, rand(3), rand(3, 5)) + test_frule(sum!, rand(ComplexF64, 1, 4), rand(3, 4)) + end + @testset "sum abs2" begin sizes = (3, 4, 7) @testset "dims = $dims" for dims in (:, 1) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 601000d7e..8fc9d89df 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -35,6 +35,12 @@ end end + @testset "mul!" begin + test_frule(mul!, rand(4), rand(4, 5), rand(5)) + test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) + test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) + end + @testset "cross" begin test_frule(cross, randn(3), randn(3)) test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) From 793f47ef8c748fe53b4f32a0fd3f454c6d32855a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 18 Jan 2022 21:22:18 -0500 Subject: [PATCH 07/18] reshape + dropdims too --- src/rulesets/Base/array.jl | 33 +++++++++++++++++++-------------- test/rulesets/Base/array.jl | 26 +++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2539b6d55..6bb08ca8b 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -91,22 +91,27 @@ function frule((_, xdot), ::typeof(reshape), x::AbstractArray, dims...) return reshape(x, dims...), reshape(xdot, dims...) end -function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}}) - A_dims = size(A) - function reshape_pullback(Ȳ) - return (NoTangent(), reshape(Ȳ, A_dims), NoTangent()) - end - return reshape(A, dims), reshape_pullback +function rrule(::typeof(reshape), A::AbstractArray, dims...) + ax = axes(A) + project = ProjectTo(A) # Projection is here for e.g. 
reshape(::Diagonal, :) + ∂dims = broadcast(Returns(NoTangent()), dims) + reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...) + return reshape(A, dims...), reshape_pullback end -function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...) - A_dims = size(A) - function reshape_pullback(Ȳ) - ∂A = reshape(Ȳ, A_dims) - ∂dims = broadcast(Returns(NoTangent()), dims) - return (NoTangent(), ∂A, ∂dims...) - end - return reshape(A, dims...), reshape_pullback +##### +##### `dropdims` +##### + +function frule((_, xdot), ::typeof(dropdims), x::AbstractArray; dims) + return dropdims(x; dims=dims), dropdims(xdot; dims=dims) +end + +function rrule(::typeof(dropdims), A::AbstractArray; dims) + ax = axes(A) + project = ProjectTo(A) + dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax))) + return dropdims(A; dims=dims), dropdims_pullback end ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index fe7c75ac9..9028f7924 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -59,12 +59,32 @@ end end @testset "reshape" begin - # fwd - test_frule(reshape, rand(4, 5), 2, :) - # rev + # Forward + test_frule(reshape, rand(4, 3), 2, :) + test_rrule(reshape, rand(4, 3), axes(rand(6, 2))) + + # Reverse test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) test_rrule(reshape, rand(4, 5), 2, :) + test_rrule(reshape, rand(4, 5), axes(rand(10, 2))) + # structured + test_rrule(reshape, transpose(rand(4)), :) + test_rrule(reshape, adjoint(rand(ComplexF64, 4)), :) + @test rrule(reshape, adjoint(rand(ComplexF64, 4)), :)[2](rand(4))[2] isa Adjoint{ComplexF64} + @test rrule(reshape, Diagonal(rand(4)), (2, :))[2](ones(2,8))[2] isa Diagonal + @test_skip test_rrule(reshape, Diagonal(rand(4)), 2, :) # DimensionMismatch("second dimension of A, 22, does not match length of x, 16") + @test_skip test_rrule(reshape, UpperTriangular(rand(4,4)), (8, 2)) +end + +@testset "dropdims" begin + # fwd + 
test_frule(dropdims, rand(4, 1); fkwargs=(; dims=2)) + # rev + test_rrule(dropdims, rand(4, 1); fkwargs=(; dims=2)) + test_rrule(dropdims, transpose(rand(4)); fkwargs=(; dims=1)) + test_rrule(dropdims, adjoint(rand(ComplexF64, 4)); fkwargs=(; dims=1)) + @test rrule(dropdims, adjoint(rand(ComplexF64, 4)); dims=1)[2](rand(4))[2] isa Adjoint{ComplexF64} end @testset "permutedims + PermutedDimsArray" begin From d1b89496fa87d7071a4287a0cdbeae6743b25dbf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Jan 2022 23:22:53 -0500 Subject: [PATCH 08/18] tests --- src/rulesets/Base/indexing.jl | 8 ++++---- test/rulesets/Base/indexing.jl | 11 +++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 8c58e17df..0421aea96 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -43,10 +43,10 @@ end ##### setindex! ##### -function frule((_, xdot), ::typeof(setindex!), x::AbstractArray, v, inds...) - v1 = x[inds...] = v - v2 = xdot[inds...] = v - return v1, v2 +function frule((_, xdot, vdot), ::typeof(setindex!), x::AbstractArray, v, inds...) + w = x[inds...] = v + wdot = xdot[inds...] = vdot + return w, wdot end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 094f26157..17fb6b1e5 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -57,3 +57,14 @@ end end end + +@testset "view" begin + test_frule(view, rand(3, 4), :, 1) + test_frule(view, rand(3, 4), 2, [1, 1, 2]) + test_frule(view, rand(3, 4), 3, 4) +end + +@testset "setindex!" 
begin + @test_skip test_frule(setindex!, rand(3, 4), rand(), 1, 2) + @test_skip test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) +end From 3ea58fb242ff067a8f5c807aec217cba7de1837d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 21 Jan 2022 18:15:48 -0500 Subject: [PATCH 09/18] 5-arg mul --- src/rulesets/LinearAlgebra/dense.jl | 28 +++++++++++++++++++++------- test/rulesets/LinearAlgebra/dense.jl | 10 ++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 690cac1bd..13d77d302 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -64,18 +64,32 @@ end ##### function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B) - mul!(C, A, B) + mul!(C, A, B) mul!(ΔC, ΔA, B) mul!(ΔC, A, ΔB, true, true) return C, ΔC end -# function frule((_, ΔC, ΔA, ΔB, Δα, Δβ), ::typeof(mul!), C::AbstractArray, A, B, α::Number, β::Number) -# mul!(C, A, B, α, β) -# mul!(ΔC, ΔA, B) -# mul!(ΔC, A, ΔB, true, true) -# return C, ΔC -# end +function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B, α::Bool, β::Bool) + # D = A*B*α + C*β + mul!(C, A, B, α, β) + # ΔD = ΔA*B*α + ΔC*β + A*ΔB*α + mul!(ΔC, ΔA, B, α, β) + mul!(ΔC, A, ΔB, α, true) + return C, ΔC +end + +function frule((_, ΔC, ΔA, ΔB, Δα, Δβ), ::typeof(mul!), C::AbstractArray, A, B, α::Number, β::Number) + # This is used twice: + AB = A * B + # ΔD = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + @. ΔC = ΔC * β + C * Δβ + AB * Δα + mul!(ΔC, ΔA, B, α, true) + mul!(ΔC, A, ΔB, α, true) + # D = A*B*α + C*β + @. 
C = AB * α + C*β # Must be done last, as C enters above + return C, ΔC +end ##### ##### `cross` diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 8fc9d89df..d8eb50eb0 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -39,6 +39,16 @@ test_frule(mul!, rand(4), rand(4, 5), rand(5)) test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) + + # Rule with α,β::Bool is only visually more complicated: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) + + # Rule with nontrivial α, β allocates A*B: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) + test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) end @testset "cross" begin From 8b11c864785df9e83b7a21526dcd0f2fa8f44ebb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 24 Jan 2022 14:00:40 -0500 Subject: [PATCH 10/18] notation changes --- src/rulesets/Base/array.jl | 131 +++++++++++++++++---------------- src/rulesets/Base/arraymath.jl | 20 ++--- src/rulesets/Base/indexing.jl | 14 ++-- src/rulesets/Base/sort.jl | 12 +-- test/rulesets/Base/array.jl | 8 +- 5 files changed, 92 insertions(+), 93 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 6bb08ca8b..051de8b6c 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -4,8 +4,8 @@ ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) 
-function frule((_, xdot), ::Type{T}, x::AbstractArray) where {T<:Array} - return T(x), T(xdot) +function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array} + return T(x), T(ẋ) end function rrule(::Type{T}, x::AbstractArray) where {T<:Array} @@ -20,8 +20,8 @@ end @non_differentiable Base.vect() -function frule((_, xdots...), ::typeof(Base.vect), xs::Number...) - return Base.vect(xs...), Base.vect(_make_real_zeros(xdots, xs)...) +function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...) + return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...) end # Case of uniform type `T`: the data passes straight through, @@ -52,43 +52,46 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N} end """ - _make_real_zeros(xdots, xs) + _instantiate_zeros(ẋs, xs) -Forward rules for `vect` or `cat` may receive a mixture of data and `ZeroTangent`s. -To avoid `vect(1, ZeroTangent(), 3)` or `hcat([1,2], ZeroTangent())`, this materialises -each zero `xdot` to be `zero(x)`. +Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s. +To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this +materialises each zero `ẋ` to be `zero(x)`. """ -_make_real_zeros(xdots, xs) = map(_real_zero, xdots, xs) -_real_zero(xdot, x) = xdot -_real_zero(xdot::AbstractZero, x) = zero(x) +_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs) +_i_zero(ẋ, x) = ẋ +_i_zero(ẋ::AbstractZero, x) = zero(x) +# Possibly this won't work for partly non-diff arrays, something like `gradient(x -> ["abc", x][end], 1)` +# may give a MethodError for `zero` but won't be wrong. # Fast paths. Should it also collapse all-Zero cases? 
-_make_real_zeros(xdots::NTuple{<:Any, <:Number}, xs) = xdots -_make_real_zeros(xdots::AbstractArray{<:Number}, xs) = xdots -_make_real_zeros(xdots::AbstractArray{<:AbstractArray}, xs) = xdots +_instantiate_zeros(ẋs::NTuple{<:Any, <:Number}, xs) = ẋs +_instantiate_zeros(ẋs::NTuple{<:Any, <:AbstractArray}, xs) = ẋs +_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs +_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs -frrule((_, xdd, _), ::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xdots, xs), xdd # not very sure! +frule((_, ẍs, _), ::typeof(_instantiate_zeros), ẋs, xs) = _instantiate_zeros(ẋs, xs), ẍs -rrule(::typeof(_make_real_zeros), xdots, xs) = _make_real_zeros(xdots, xs), dydots -> (NoTangent(), dydots, NoTangent()) +rrule(::typeof(_instantiate_zeros), ẋs, xs) = _instantiate_zeros(ẋs, xs), dẏs -> (NoTangent(), dẏs, NoTangent()) ##### ##### `copyto!` ##### -function frule((_, ydot, xdot), ::typeof(copyto!), y::AbstractArray, x) - return copyto!(y, x), copyto!(ydot, xdot) +function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x) + return copyto!(y, x), copyto!(ẏ, ẋ) end -function frule((_, ydot, _, xdot), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) - return copyto!(y, i, x, js...), copyto!(ydot, i, xdot, js...) +function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) + return copyto!(y, i, x, js...), copyto!(ẏ, i, ẋ, js...) end ##### ##### `reshape` ##### -function frule((_, xdot), ::typeof(reshape), x::AbstractArray, dims...) - return reshape(x, dims...), reshape(xdot, dims...) +function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...) + return reshape(x, dims...), reshape(ẋ, dims...) end function rrule(::typeof(reshape), A::AbstractArray, dims...) 
@@ -103,8 +106,8 @@ end ##### `dropdims` ##### -function frule((_, xdot), ::typeof(dropdims), x::AbstractArray; dims) - return dropdims(x; dims=dims), dropdims(xdot; dims=dims) +function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims) + return dropdims(x; dims=dims), dropdims(ẋ; dims=dims) end function rrule(::typeof(dropdims), A::AbstractArray; dims) @@ -118,12 +121,12 @@ end ##### `permutedims` ##### -function frule((_, xdot), ::typeof(permutedims), x::AbstractArray, perm...) - return permutedims(x, perm...), permutedims(xdot, perm...) +function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...) + return permutedims(x, perm...), permutedims(ẋ, perm...) end -function frule((_, ydot, xdot), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...) - return permutedims!(y, x, perm...), permutedims!(ydot, xdot, perm...) +function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...) + return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...) end function rrule(::typeof(permutedims), x::AbstractVector) @@ -148,8 +151,8 @@ end ##### `repeat` ##### -function frule((_, xsdot), ::typeof(repeat), xs::AbstractArray, cnt...; kw...) - return repeat(xs, cnt...; kw...), repeat(xsdot, cnt...; kw...) +function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...) + return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...) end function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs))) @@ -191,8 +194,8 @@ end ##### `hcat` ##### -function frule((_, xdots...), ::typeof(hcat), xs...) - return hcat(xs...), hcat(_make_real_zeros(xdots, xs)...) +function frule((_, ẋs...), ::typeof(hcat), xs...) + return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...) end function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) @@ -229,8 +232,8 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) 
return Y, hcat_pullback end -function frule((_, _, Adots), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) - return reduce(hcat, As), reduce(hcat, _make_real_zeros(Adots, As)) +function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(hcat, As), reduce(hcat, _instantiate_zeros(Ȧs, As)) end function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) @@ -261,8 +264,8 @@ end ##### `vcat` ##### -function frule((_, xdots...), ::typeof(vcat), xs...) - return vcat(xs...), vcat(_make_real_zeros(xdots, xs)...) +function frule((_, ẋs...), ::typeof(vcat), xs...) + return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...) end function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) @@ -297,8 +300,8 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) return Y, vcat_pullback end -function frule((_, _, Adots), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) - return reduce(vcat, As), reduce(vcat, _make_real_zeros(Adots, As)) +function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(vcat, As), reduce(vcat, _instantiate_zeros(Ȧs, As)) end function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) @@ -324,8 +327,8 @@ end _val(::Val{x}) where {x} = x -function frule((_, xdots...), ::typeof(cat), xs...; dims) - return cat(xs...; dims), cat(_make_real_zeros(xdots, xs)...; dims) +function frule((_, ẋs...), ::typeof(cat), xs...; dims) + return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims) end function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) @@ -366,8 +369,8 @@ end ##### `hvcat` ##### -function frule((_, _, xdots...), ::typeof(hvcat), rows, xs...) - return hvcat(rows, xs...), hvcat(rows, _make_real_zeros(xdots, xs)...) +function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...) 
+ return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...) end function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...) @@ -406,12 +409,12 @@ end # 1-dim case allows start/stop, N-dim case takes dims keyword # whose defaults changed in Julia 1.6... just pass them all through: -function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) - return reverse(x, args...; kw...), reverse(xdot, args...; kw...) +function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) + return reverse(x, args...; kw...), reverse(ẋ, args...; kw...) end -function frule((_, xdot), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...) - return reverse!(x, args...; kw...), reverse!(xdot, args...; kw...) +function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...) + return reverse!(x, args...; kw...), reverse!(ẋ, args...; kw...) end function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) @@ -427,12 +430,12 @@ end ##### `circshift` ##### -function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts) - return circshift(x, shifts), circshift(xdot, shifts) +function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts) + return circshift(x, shifts), circshift(ẋ, shifts) end -function frule((_, ydot, xdot), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts) - return circshift!(y, x, shifts), circshift!(ydot, xdot, shifts) +function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts) + return circshift!(y, x, shifts), circshift!(ẏ, ẋ, shifts) end function rrule(::typeof(circshift), x::AbstractArray, shifts) @@ -448,12 +451,12 @@ end ##### `fill` ##### -function frule((_, xdot), ::typeof(fill), x::Any, dims...) - return fill(x, dims...), fill(xdot, dims...) +function frule((_, ẋ), ::typeof(fill), x::Any, dims...) + return fill(x, dims...), fill(ẋ, dims...) 
end -function frule((_, ydot, xdot), ::typeof(fill!), y::AbstractArray, x::Any) - return fill!(y, x), fill!(ydot, xdot) +function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any) + return fill!(y, x), fill!(ẏ, ẋ) end function rrule(::typeof(fill), x::Any, dims...) @@ -467,9 +470,9 @@ end ##### `filter` ##### -function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray) +function frule((_, _, ẋ), ::typeof(filter), f, x::AbstractArray) inds = findall(f, x) - return x[inds], xdot[inds] + return x[inds], ẋ[inds] end function rrule(::typeof(filter), f, x::AbstractArray) @@ -489,9 +492,9 @@ end for findm in (:findmin, :findmax) findm_pullback = Symbol(findm, :_pullback) - @eval function frule((_, xdot), ::typeof($findm), x; dims=:) + @eval function frule((_, ẋ), ::typeof($findm), x; dims=:) y, ind = $findm(x; dims=dims) - return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent()) + return (y, ind), Tangent{typeof((y, ind))}(ẋ[ind], NoTangent()) end @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) @@ -538,8 +541,8 @@ end # Allow for second derivatives, by writing rules for `_zerolike_writeat`; # these rules are the reason it takes a `dims` argument. -function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...) - return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...) +function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...) + return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...) end function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) 
@@ -554,9 +557,9 @@ end # These rules for `maximum` pick the same subgradient as `findmax`: -function frule((_, xdot), ::typeof(maximum), x; dims=:) +function frule((_, ẋ), ::typeof(maximum), x; dims=:) y, ind = findmax(x; dims=dims) - return y, xdot[ind] + return y, ẋ[ind] end function rrule(::typeof(maximum), x::AbstractArray; dims=:) @@ -565,9 +568,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:) return y, maximum_pullback end -function frule((_, xdot), ::typeof(minimum), x; dims=:) +function frule((_, ẋ), ::typeof(minimum), x; dims=:) y, ind = findmin(x; dims=dims) - return y, xdot[ind] + return y, ẋ[ind] end function rrule(::typeof(minimum), x::AbstractArray; dims=:) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index e20d02acf..f1409515c 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,9 +19,9 @@ end ##### `*` ##### -frule((_, Adot, Bdot), ::typeof(*), A, B) = A * B, muladd(Adot, B, A * Bdot) +frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB) -frule((_, Adot, Bdot, Cdot), ::typeof(*), A, B, C) = A*B*C, Adot*B*C + A*Bdot*C + A*B*Cdot +frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC function rrule( @@ -210,9 +210,9 @@ end # VERSION ##### `muladd` ##### -function frule((_, Adot, Bdot, zdot), ::typeof(muladd), A, B, z) +function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z) Ω = muladd(A, B, z) - return Ω, Adot * B .+ A * Bdot .+ zdot + return Ω, ΔA * B .+ A * ΔB .+ Δz end function rrule( @@ -362,11 +362,11 @@ end ##### `\`, `/` matrix-scalar_rule ##### -function frule((_, Adot, bdot), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) - return A/b, Adot/b - A*(bdot/b^2) +function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) + return A/b, ΔA/b - A*(Δb/b^2) end -function frule((_, adot, Bdot), ::typeof(\), a::CommutativeMulNumber, 
B::AbstractArray{<:CommutativeMulNumber}) - return B/a, Bdot/a - B*(adot/a^2) +function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}) + return B/a, ΔB/a - B*(Δa/a^2) end function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) @@ -396,7 +396,7 @@ end ##### Negation (Unary -) ##### -frule((_, Adot), ::typeof(-), A::AbstractArray) = -A, -Adot +frule((_, ΔA), ::typeof(-), A::AbstractArray) = -A, -ΔA function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) @@ -410,7 +410,7 @@ end ##### Addition (Multiarg `+`) ##### -frule((_, Adots...), ::typeof(+), As::AbstractArray...) = +(As...), +(Adots...) +frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...) function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 0421aea96..ab815d33d 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -2,8 +2,8 @@ ##### getindex ##### -function frule((_, xdot), ::typeof(getindex), x::AbstractArray, inds...) - return x[inds...], xdot[inds...] +function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) + return x[inds...], ẋ[inds...] end function rrule(::typeof(getindex), x::Array{<:Number}, inds...) @@ -35,18 +35,16 @@ end ##### view ##### -function frule((_, xdot), ::typeof(view), x::AbstractArray, inds...) - return view(x, inds...), view(xdot, inds...) +function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) + return view(x, inds...), view(ẋ, inds...) end ##### ##### setindex! ##### -function frule((_, xdot, vdot), ::typeof(setindex!), x::AbstractArray, v, inds...) - w = x[inds...] = v - wdot = xdot[inds...] = vdot - return w, wdot +function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) + return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) 
end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index b6939e33f..42f674a24 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -2,9 +2,9 @@ ##### `sort` ##### -function frule((_, xsdot, _), ::typeof(partialsort), xs::AbstractVector, k; kw...) +function frule((_, ẋs, _), ::typeof(partialsort), xs::AbstractVector, k; kw...) inds = partialsortperm(xs, k; kw...) - return xs[inds], xsdot[inds] + return xs[inds], ẋs[inds] end function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...) @@ -25,9 +25,9 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin return ys, partialsort_pullback end -function frule((_, xsdot), ::typeof(sort), xs::AbstractVector; kw...) +function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...) inds = sortperm(xs; kw...) - return xs[inds], xsdot[inds] + return xs[inds], ẋs[inds] end function rrule(::typeof(sort), xs::AbstractVector; kwargs...) @@ -52,10 +52,10 @@ end ##### `sortslices` ##### -function frule((_, xdot), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) +function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) - return x[inds...], xdot[inds...] + return x[inds...], ẋ[inds...] end function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 9028f7924..64470c23c 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -41,13 +41,11 @@ end test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) test_rrule(Base.vect, (5.0, 4.0), (y=randn(3),); check_inferred=false) end - @testset "_make_real_zeros" begin + @testset "_instantiate_zeros" begin # This is an internal function also used for `cat` etc. # It has its own rules to allow for 2nd derivatives. 
- @eval using ChainRules: _make_real_zeros - @test_skip test_frule(_make_real_zeros, Tuple(rand(3)), Tuple(rand(3))) - @test_skip test_rrule(_make_real_zeros, Tuple(rand(3)), Tuple(rand(3))) - # Not sure these are defined right! Currently fail due to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/229 + @eval using ChainRules: _instantiate_zeros + test_frule(_instantiate_zeros, Tuple(rand(3)), Tuple(rand(3))) end end From c509d5e35fd54e9e142d46443769f82258712960 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 24 Jan 2022 14:33:20 -0500 Subject: [PATCH 11/18] rm 2nd order rules --- src/rulesets/Base/array.jl | 8 ++------ test/rulesets/Base/array.jl | 5 +++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 051de8b6c..b42e45304 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -65,15 +65,11 @@ _i_zero(ẋ::AbstractZero, x) = zero(x) # may give a MethodError for `zero` but won't be wrong. # Fast paths. Should it also collapse all-Zero cases? 
-_instantiate_zeros(ẋs::NTuple{<:Any, <:Number}, xs) = ẋs -_instantiate_zeros(ẋs::NTuple{<:Any, <:AbstractArray}, xs) = ẋs +_instantiate_zeros(ẋs::Tuple{Vararg{<:Number}}, xs) = ẋs +_instantiate_zeros(ẋs::Tuple{Vararg{<:AbstractArray}}, xs) = ẋs _instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs -frule((_, ẍs, _), ::typeof(_instantiate_zeros), ẋs, xs) = _instantiate_zeros(ẋs, xs), ẍs - -rrule(::typeof(_instantiate_zeros), ẋs, xs) = _instantiate_zeros(ẋs, xs), dẏs -> (NoTangent(), dẏs, NoTangent()) - ##### ##### `copyto!` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 64470c23c..0f2f3d4c0 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -43,9 +43,10 @@ end end @testset "_instantiate_zeros" begin # This is an internal function also used for `cat` etc. - # It has its own rules to allow for 2nd derivatives. @eval using ChainRules: _instantiate_zeros - test_frule(_instantiate_zeros, Tuple(rand(3)), Tuple(rand(3))) + # Check these hit the fast path, unrealistic input so that map would fail: + @test _instantiate_zeros((true, 2 , 3.0), ()) == (1, 2, 3) + @test _instantiate_zeros((1:2, [3, 4]), ()) == (1:2, 3:4) end end From 51393776ce22279de5add338814a73e0226ec2df Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 24 Jan 2022 14:45:16 -0500 Subject: [PATCH 12/18] don't skip setindex --- test/rulesets/Base/indexing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 17fb6b1e5..74a4032f3 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -65,6 +65,6 @@ end end @testset "setindex!" 
begin - @test_skip test_frule(setindex!, rand(3, 4), rand(), 1, 2) - @test_skip test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) + test_frule(setindex!, rand(3, 4), rand(), 1, 2) + test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) end From ae9d76b556edca5c1c142e1192e48bb36f609497 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 24 Jan 2022 15:08:57 -0500 Subject: [PATCH 13/18] AbstractArray constructors --- src/rulesets/Base/array.jl | 11 +++++++++++ test/rulesets/Base/array.jl | 13 +++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b42e45304..6272c6bbb 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -8,12 +8,23 @@ function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array} return T(x), T(ẋ) end +function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T} + return AbstractArray{T}(x), AbstractArray{T}(ẋ) +end + function rrule(::Type{T}, x::AbstractArray) where {T<:Array} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) return T(x), Array_pullback end +# This abstract one is used for `float(x)` and other float conversion purposes: +function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T} + project_x = ProjectTo(x) + AbstractArray_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return AbstractArray{T}(x), AbstractArray_pullback +end + ##### ##### `vect` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 0f2f3d4c0..bd193be0d 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,10 +1,9 @@ @testset "Array constructors" begin - + @testset "undef" begin # We can't use test_rrule here (as it's currently implemented) because the elements of # the array have arbitrary values. 
The only thing we can do is ensure that we're getting # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. # Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202 - @testset "undef" begin val, pullback = rrule(Array{Float64}, undef, 5) @test size(val) == (5, ) @test val isa Array{Float64, 1} @@ -22,6 +21,16 @@ test_rrule(Array{ComplexF64}, randn(3)) end end +@testset "AbstractArray constructors" begin + # These are what float(x) calls, but it's trivial with floating point numbers: + test_frule(AbstractArray{Float32}, rand(3); atol=0.01) + test_frule(AbstractArray{Float32}, Diagonal(rand(4)); atol=0.01) + # rev + test_rrule(AbstractArray{Float32}, rand(3); atol=0.01) + test_rrule(AbstractArray{Float32}, Diagonal(rand(4)); atol=0.01) + # Check with integers: + rrule(AbstractArray{Float64}, [1, 2, 3])[2]([1, 10, 100]) == (NoTangent(), [1.0, 10.0, 100.0]) +end @testset "vect" begin test_rrule(Base.vect) From b64265ef9e84d3964b26f08ce224bce6098521be Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 24 Jan 2022 15:35:15 -0500 Subject: [PATCH 14/18] reshape tests --- test/rulesets/Base/array.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index bd193be0d..a43f3b65e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -69,7 +69,8 @@ end @testset "reshape" begin # Forward test_frule(reshape, rand(4, 3), 2, :) - test_rrule(reshape, rand(4, 3), axes(rand(6, 2))) + test_frule(reshape, rand(4, 3), axes(rand(6, 2))) + @test_skip test_frule(reshape, Diagonal(rand(4)), 2, :) # Reverse test_rrule(reshape, rand(4, 5), (2, 10)) From 6f047c32051bc689157b4c743d797e545f2fb6d6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 25 Jan 2022 13:35:24 -0500 Subject: [PATCH 15/18] Apply 4 suggestions Co-authored-by: Lyndon White --- 
src/rulesets/Base/array.jl | 4 ++-- src/rulesets/Base/sort.jl | 1 + test/rulesets/Base/indexing.jl | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 6272c6bbb..fdd096c6e 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -114,14 +114,14 @@ end ##### function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims) - return dropdims(x; dims=dims), dropdims(ẋ; dims=dims) + return dropdims(x; dims), dropdims(ẋ; dims) end function rrule(::typeof(dropdims), A::AbstractArray; dims) ax = axes(A) project = ProjectTo(A) dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax))) - return dropdims(A; dims=dims), dropdims_pullback + return dropdims(A; dims), dropdims_pullback end ##### diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 42f674a24..1a5a22232 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -54,6 +54,7 @@ end function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) + firstindex(x, d) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here.")) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) return x[inds...], ẋ[inds...] 
end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 74a4032f3..b9219825e 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -7,8 +7,8 @@ test_frule(getindex, x, 2, 1) test_frule(getindex, x, CartesianIndex(2, 3)) - test_rrule(getindex, x, 2:3) - test_rrule(getindex, x, (:), 2:3) + test_frule(getindex, x, 2:3) + test_frule(getindex, x, (:), 2:3) end @testset "single element" begin From 016064f1e06e30b17adafa8d4174898639716582 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:53:15 -0500 Subject: [PATCH 16/18] fixup, bump --- Project.toml | 2 +- src/rulesets/Base/sort.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6c3f90242..52a02566c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.23" +version = "1.24" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 1a5a22232..504feb613 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -54,7 +54,7 @@ end function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) - firstindex(x, d) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here.")) + firstindex(x, dims) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here.")) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) return x[inds...], ẋ[inds...] 
end From a773f9a37f702594f5a46ead994638ba8ca291c3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 25 Jan 2022 16:10:29 -0500 Subject: [PATCH 17/18] several comments, and one rule for PermutedDimsArray --- src/rulesets/Base/array.jl | 4 ++++ test/rulesets/Base/array.jl | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index fdd096c6e..6af3cde2a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -136,6 +136,10 @@ function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::Abstr return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...) end +function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm) + return PermutedDimsArray(x, perm), PermutedDimsArray(ẋ, perm) +end + function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index a43f3b65e..87c352ab1 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -70,7 +70,7 @@ end # Forward test_frule(reshape, rand(4, 3), 2, :) test_frule(reshape, rand(4, 3), axes(rand(6, 2))) - @test_skip test_frule(reshape, Diagonal(rand(4)), 2, :) + @test_skip test_frule(reshape, Diagonal(rand(4)), 2, :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239 # Reverse test_rrule(reshape, rand(4, 5), (2, 10)) @@ -83,7 +83,7 @@ end @test rrule(reshape, adjoint(rand(ComplexF64, 4)), :)[2](rand(4))[2] isa Adjoint{ComplexF64} @test rrule(reshape, Diagonal(rand(4)), (2, :))[2](ones(2,8))[2] isa Diagonal @test_skip test_rrule(reshape, Diagonal(rand(4)), 2, :) # DimensionMismatch("second dimension of A, 22, does not match length of x, 16") - @test_skip test_rrule(reshape, UpperTriangular(rand(4,4)), (8, 2)) + @test_skip test_rrule(reshape, 
UpperTriangular(rand(4,4)), (8, 2)) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239 end @testset "dropdims" begin @@ -101,6 +101,7 @@ end test_frule(permutedims, rand(5)) test_frule(permutedims, rand(3, 4), (2, 1)) test_frule(permutedims!, rand(4,3), rand(3, 4), (2, 1)) + test_frule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) # Reverse test_rrule(permutedims, rand(5)) @@ -111,7 +112,7 @@ end @test invperm((3, 1, 2)) != (3, 1, 2) test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2)) - @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) + @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/240 x = rand(2, 3, 4) dy = rand(4, 2, 3) @test rrule(permutedims, x, (3, 1, 2))[2](dy)[2] == rrule(PermutedDimsArray, x, (3, 1, 2))[2](dy)[2] From 327bf263463cc9a57373e2c109ce7beb8352e16b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 25 Jan 2022 16:19:25 -0500 Subject: [PATCH 18/18] in fact sortslices is fine with offsets --- src/rulesets/Base/sort.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 504feb613..42f674a24 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -54,7 +54,6 @@ end function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) - firstindex(x, dims) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here.")) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) return x[inds...], ẋ[inds...] end