diff --git a/Project.toml b/Project.toml index 6c3f90242..52a02566c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.23" +version = "1.24" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b6ae01ce5..6af3cde2a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -4,18 +4,37 @@ ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) +function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array} + return T(x), T(ẋ) +end + +function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T} + return AbstractArray{T}(x), AbstractArray{T}(ẋ) +end + function rrule(::Type{T}, x::AbstractArray) where {T<:Array} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) return T(x), Array_pullback end +# This abstract one is used for `float(x)` and other float conversion purposes: +function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T} + project_x = ProjectTo(x) + AbstractArray_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return AbstractArray{T}(x), AbstractArray_pullback +end + ##### ##### `vect` ##### @non_differentiable Base.vect() +function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...) + return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...) +end + # Case of uniform type `T`: the data passes straight through, # so no projection should be required. function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} @@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N} return Base.vect(X...), vect_pullback end +""" + _instantiate_zeros(ẋs, xs) + +Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s. 
+To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this +materialises each zero `ẋ` to be `zero(x)`. +""" +_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs) +_i_zero(ẋ, x) = ẋ +_i_zero(ẋ::AbstractZero, x) = zero(x) +# Possibly this won't work for partly non-diff arrays, something like `gradient(x -> ["abc", x][end], 1)` +# may give a MethodError for `zero` but won't be wrong. + +# Fast paths. Should it also collapse all-Zero cases? +_instantiate_zeros(ẋs::Tuple{Vararg{<:Number}}, xs) = ẋs +_instantiate_zeros(ẋs::Tuple{Vararg{<:AbstractArray}}, xs) = ẋs +_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs +_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs + +##### +##### `copyto!` +##### + +function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x) + return copyto!(y, x), copyto!(ẏ, ẋ) +end + +function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...) + return copyto!(y, i, x, js...), copyto!(ẏ, i, ẋ, js...) +end + ##### ##### `reshape` ##### -function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}}) - A_dims = size(A) - function reshape_pullback(Ȳ) - return (NoTangent(), reshape(Ȳ, A_dims), NoTangent()) - end - return reshape(A, dims), reshape_pullback +function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...) + return reshape(x, dims...), reshape(ẋ, dims...) end -function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...) - A_dims = size(A) - function reshape_pullback(Ȳ) - ∂A = reshape(Ȳ, A_dims) - ∂dims = broadcast(Returns(NoTangent()), dims) - return (NoTangent(), ∂A, ∂dims...) - end +function rrule(::typeof(reshape), A::AbstractArray, dims...) + ax = axes(A) + project = ProjectTo(A) # Projection is here for e.g. reshape(::Diagonal, :) + ∂dims = broadcast(Returns(NoTangent()), dims) + reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...) 
return reshape(A, dims...), reshape_pullback end +##### +##### `dropdims` +##### + +function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims) + return dropdims(x; dims), dropdims(ẋ; dims) +end + +function rrule(::typeof(dropdims), A::AbstractArray; dims) + ax = axes(A) + project = ProjectTo(A) + dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax))) + return dropdims(A; dims), dropdims_pullback +end + ##### ##### `permutedims` ##### +function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...) + return permutedims(x, perm...), permutedims(ẋ, perm...) +end + +function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...) + return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...) +end + +function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm) + return PermutedDimsArray(x, perm), PermutedDimsArray(ẋ, perm) +end + function rrule(::typeof(permutedims), x::AbstractVector) project = ProjectTo(x) permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy)))) @@ -91,6 +162,10 @@ end ##### `repeat` ##### +function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...) + return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...) +end + function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs))) project_Xs = ProjectTo(xs) @@ -130,6 +205,10 @@ end ##### `hcat` ##### +function frule((_, ẋs...), ::typeof(hcat), xs...) + return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...) +end + function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability @@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) 
return Y, hcat_pullback end +function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(hcat, As), reduce(hcat, _instantiate_zeros(Ȧs, As)) +end + function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) widths = map(A -> size(A,2), As) function reduce_hcat_pullback_2(dY) @@ -192,6 +275,10 @@ end ##### `vcat` ##### +function frule((_, ẋs...), ::typeof(vcat), xs...) + return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...) +end + function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) Y = vcat(Xs...) ndimsY = Val(ndims(Y)) @@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) return Y, vcat_pullback end +function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) + return reduce(vcat, As), reduce(vcat, _instantiate_zeros(Ȧs, As)) +end + function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) Y = reduce(vcat, As) ndimsY = Val(ndims(Y)) @@ -247,6 +338,10 @@ end _val(::Val{x}) where {x} = x +function frule((_, ẋs...), ::typeof(cat), xs...; dims) + return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims) +end + function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) Y = cat(Xs...; dims=dims) cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims) @@ -285,6 +380,10 @@ end ##### `hvcat` ##### +function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...) + return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...) +end + function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...) Y = hvcat(rows, values...) cols = size(Y,2) @@ -321,8 +420,12 @@ end # 1-dim case allows start/stop, N-dim case takes dims keyword # whose defaults changed in Julia 1.6... just pass them all through: -function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) 
- return reverse(x, args...; kw...), reverse(xdot, args...; kw...) +function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) + return reverse(x, args...; kw...), reverse(ẋ, args...; kw...) +end + +function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...) + return reverse!(x, args...; kw...), reverse!(ẋ, args...; kw...) end function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...) @@ -338,8 +441,12 @@ end ##### `circshift` ##### -function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts) - return circshift(x, shifts), circshift(xdot, shifts) +function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts) + return circshift(x, shifts), circshift(ẋ, shifts) +end + +function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts) + return circshift!(y, x, shifts), circshift!(ẏ, ẋ, shifts) end function rrule(::typeof(circshift), x::AbstractArray, shifts) @@ -355,8 +462,12 @@ end ##### `fill` ##### -function frule((_, xdot), ::typeof(fill), x::Any, dims...) - return fill(x, dims...), fill(xdot, dims...) +function frule((_, ẋ), ::typeof(fill), x::Any, dims...) + return fill(x, dims...), fill(ẋ, dims...) +end + +function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any) + return fill!(y, x), fill!(ẏ, ẋ) end function rrule(::typeof(fill), x::Any, dims...) 
@@ -370,9 +481,9 @@ end ##### `filter` ##### -function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray) +function frule((_, _, ẋ), ::typeof(filter), f, x::AbstractArray) inds = findall(f, x) - return x[inds], xdot[inds] + return x[inds], ẋ[inds] end function rrule(::typeof(filter), f, x::AbstractArray) @@ -392,9 +503,9 @@ end for findm in (:findmin, :findmax) findm_pullback = Symbol(findm, :_pullback) - @eval function frule((_, xdot), ::typeof($findm), x; dims=:) + @eval function frule((_, ẋ), ::typeof($findm), x; dims=:) y, ind = $findm(x; dims=dims) - return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent()) + return (y, ind), Tangent{typeof((y, ind))}(ẋ[ind], NoTangent()) end @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) @@ -441,8 +552,8 @@ end # Allow for second derivatives, by writing rules for `_zerolike_writeat`; # these rules are the reason it takes a `dims` argument. -function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...) - return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...) +function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...) + return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...) end function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) 
@@ -457,9 +568,9 @@ end # These rules for `maximum` pick the same subgradient as `findmax`: -function frule((_, xdot), ::typeof(maximum), x; dims=:) +function frule((_, ẋ), ::typeof(maximum), x; dims=:) y, ind = findmax(x; dims=dims) - return y, xdot[ind] + return y, ẋ[ind] end function rrule(::typeof(maximum), x::AbstractArray; dims=:) @@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:) return y, maximum_pullback end -function frule((_, xdot), ::typeof(minimum), x; dims=:) +function frule((_, ẋ), ::typeof(minimum), x; dims=:) y, ind = findmin(x; dims=dims) - return y, xdot[ind] + return y, ẋ[ind] end function rrule(::typeof(minimum), x::AbstractArray; dims=:) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 9b442c502..f1409515c 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -19,6 +19,10 @@ end ##### `*` ##### +frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB) + +frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC + function rrule( ::typeof(*), @@ -88,7 +92,9 @@ function rrule( end - +##### +##### `*` matrix-scalar_rule +##### function rrule( ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} @@ -204,6 +210,11 @@ end # VERSION ##### `muladd` ##### +function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z) + Ω = muladd(A, B, z) + return Ω, ΔA * B .+ A * ΔB .+ Δz +end + function rrule( ::typeof(muladd), A::AbstractMatrix{<:CommutativeMulNumber}, @@ -351,6 +362,13 @@ end ##### `\`, `/` matrix-scalar_rule ##### +function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) + return A/b, ΔA/b - A*(Δb/b^2) +end +function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}) + return B/a, ΔB/a - B*(Δa/a^2) +end + function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) Y = 
A/b function slash_pullback_scalar(ȳ) @@ -378,6 +396,8 @@ end ##### Negation (Unary -) ##### +frule((_, ΔA), ::typeof(-), A::AbstractArray) = -A, -ΔA + function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ)) @@ -390,6 +410,8 @@ end ##### Addition (Multiarg `+`) ##### +frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...) + function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) arr_axs = map(axes, arrs) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 65dd55631..ab815d33d 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -2,6 +2,10 @@ ##### getindex ##### +function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) + return x[inds...], ẋ[inds...] +end + function rrule(::typeof(getindex), x::Array{<:Number}, inds...) # removes any logical indexing, CartesianIndex etc # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int @@ -27,3 +31,21 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) return y, getindex_pullback end +##### +##### view +##### + +function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) + return view(x, inds...), view(ẋ, inds...) +end + +##### +##### setindex! +##### + +function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) + return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) 
+end + + + diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 21bf194be..fe4e68b68 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -5,10 +5,15 @@ function frule((_, ẋ), ::typeof(sum), x::Tuple) return sum(x), sum(ẋ) end + function frule((_, ẋ), ::typeof(sum), x; dims=:) return sum(x; dims=dims), sum(ẋ; dims=dims) end +function frule((_, ẏ, ẋ), ::typeof(sum!), y::AbstractArray, x::AbstractArray) + return sum!(y, x), sum!(ẏ, ẋ) +end + function rrule(::typeof(sum), x::AbstractArray; dims=:) project = ProjectTo(x) y = sum(x; dims=dims) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index be7840c8c..42f674a24 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -2,6 +2,11 @@ ##### `sort` ##### +function frule((_, ẋs, _), ::typeof(partialsort), xs::AbstractVector, k; kw...) + inds = partialsortperm(xs, k; kw...) + return xs[inds], ẋs[inds] +end + function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...) inds = partialsortperm(xs, k; kwargs...) ys = xs[inds] @@ -20,6 +25,11 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin return ys, partialsort_pullback end +function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...) + inds = sortperm(xs; kw...) + return xs[inds], ẋs[inds] +end + function rrule(::typeof(sort), xs::AbstractVector; kwargs...) inds = sortperm(xs; kwargs...) ys = xs[inds] @@ -42,6 +52,12 @@ end ##### `sortslices` ##### +function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) + p = sortperm(collect(eachslice(x; dims=dims)); kw...) + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) + return x[inds...], ẋ[inds...] +end + function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? 
p : (:), ndims(x)) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 3cb362900..13d77d302 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -59,6 +59,38 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} return z, dot_pullback end +##### +##### `mul!` +##### + +function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B) + mul!(C, A, B) + mul!(ΔC, ΔA, B) + mul!(ΔC, A, ΔB, true, true) + return C, ΔC +end + +function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B, α::Bool, β::Bool) + # D = A*B*α + C*β + mul!(C, A, B, α, β) + # ΔD = ΔA*B*α + ΔC*β + A*ΔB*α + mul!(ΔC, ΔA, B, α, β) + mul!(ΔC, A, ΔB, α, true) + return C, ΔC +end + +function frule((_, ΔC, ΔA, ΔB, Δα, Δβ), ::typeof(mul!), C::AbstractArray, A, B, α::Number, β::Number) + # This is used twice: + AB = A * B + # ΔD = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + @. ΔC = ΔC * β + C * Δβ + AB * Δα + mul!(ΔC, ΔA, B, α, true) + mul!(ΔC, A, ΔB, α, true) + # D = A*B*α + C*β + @. C = AB * α + C*β # Must be done last, as C enters above + return C, ΔC +end + ##### ##### `cross` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 7a82bd350..87c352ab1 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,16 +1,19 @@ @testset "Array constructors" begin - + @testset "undef" begin # We can't use test_rrule here (as it's currently implemented) because the elements of # the array have arbitrary values. The only thing we can do is ensure that we're getting # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. 
# Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202 - @testset "undef" begin val, pullback = rrule(Array{Float64}, undef, 5) @test size(val) == (5, ) @test val isa Array{Float64, 1} @test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent()) end @testset "from existing array" begin + # fwd + test_frule(Array, randn(2, 5)) + test_frule(Array, Diagonal(randn(5))) + # rev test_rrule(Array, randn(2, 5)) test_rrule(Array, Diagonal(randn(5))) test_rrule(Matrix, Diagonal(randn(5))) @@ -18,6 +21,16 @@ test_rrule(Array{ComplexF64}, randn(3)) end end +@testset "AbstractArray constructors" begin + # These are what float(x) calls, but it's trivial with floating point numbers: + test_frule(AbstractArray{Float32}, rand(3); atol=0.01) + test_frule(AbstractArray{Float32}, Diagonal(rand(4)); atol=0.01) + # rev + test_rrule(AbstractArray{Float32}, rand(3); atol=0.01) + test_rrule(AbstractArray{Float32}, Diagonal(rand(4)); atol=0.01) + # Check with integers: + @test rrule(AbstractArray{Float64}, [1, 2, 3])[2]([1, 10, 100]) == (NoTangent(), [1.0, 10.0, 100.0]) +end @testset "vect" begin test_rrule(Base.vect) @@ -27,6 +40,9 @@ end test_rrule(Base.vect, randn(2, 2), randn(3, 3)) end @testset "inhomogeneous type" begin + # fwd + test_frule(Base.vect, 5.0, 3f0) + # rev test_rrule( Base.vect, 5.0, 3f0; atol=1e-6, rtol=1e-6, @@ -34,17 +50,61 @@ end test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) test_rrule(Base.vect, (5.0, 4.0), (y=randn(3),); check_inferred=false) end + @testset "_instantiate_zeros" begin + # This is an internal function also used for `cat` etc. + @eval using ChainRules: _instantiate_zeros + # Check these hit the fast path, unrealistic input so that map would fail: + @test _instantiate_zeros((true, 2 , 3.0), ()) == (1, 2, 3) + @test _instantiate_zeros((1:2, [3, 4]), ()) == (1:2, 3:4) + end +end + +@testset "copyto!" 
begin + test_frule(copyto!, rand(5), rand(5)) + test_frule(copyto!, rand(10), 3, rand(5)) + test_frule(copyto!, rand(10), 2, rand(5), 2) + test_frule(copyto!, rand(10), 2, rand(5), 2, 4) end @testset "reshape" begin + # Forward + test_frule(reshape, rand(4, 3), 2, :) + test_frule(reshape, rand(4, 3), axes(rand(6, 2))) + @test_skip test_frule(reshape, Diagonal(rand(4)), 2, :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239 + + # Reverse test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) test_rrule(reshape, rand(4, 5), 2, :) + test_rrule(reshape, rand(4, 5), axes(rand(10, 2))) + # structured + test_rrule(reshape, transpose(rand(4)), :) + test_rrule(reshape, adjoint(rand(ComplexF64, 4)), :) + @test rrule(reshape, adjoint(rand(ComplexF64, 4)), :)[2](rand(4))[2] isa Adjoint{ComplexF64} + @test rrule(reshape, Diagonal(rand(4)), (2, :))[2](ones(2,8))[2] isa Diagonal + @test_skip test_rrule(reshape, Diagonal(rand(4)), 2, :) # DimensionMismatch("second dimension of A, 22, does not match length of x, 16") + @test_skip test_rrule(reshape, UpperTriangular(rand(4,4)), (8, 2)) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239 +end + +@testset "dropdims" begin + # fwd + test_frule(dropdims, rand(4, 1); fkwargs=(; dims=2)) + # rev + test_rrule(dropdims, rand(4, 1); fkwargs=(; dims=2)) + test_rrule(dropdims, transpose(rand(4)); fkwargs=(; dims=1)) + test_rrule(dropdims, adjoint(rand(ComplexF64, 4)); fkwargs=(; dims=1)) + @test rrule(dropdims, adjoint(rand(ComplexF64, 4)); dims=1)[2](rand(4))[2] isa Adjoint{ComplexF64} end @testset "permutedims + PermutedDimsArray" begin - test_rrule(permutedims, rand(5)) + # Forward + test_frule(permutedims, rand(5)) + test_frule(permutedims, rand(3, 4), (2, 1)) + test_frule(permutedims!, rand(4,3), rand(3, 4), (2, 1)) + test_frule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) + # Reverse + test_rrule(permutedims, rand(5)) test_rrule(permutedims, rand(3, 4), (2, 1)) test_rrule(permutedims, 
Diagonal(rand(5)), (2, 1)) # Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all @@ -52,14 +112,18 @@ end @test invperm((3, 1, 2)) != (3, 1, 2) test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2)) - @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) + @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/240 x = rand(2, 3, 4) dy = rand(4, 2, 3) @test rrule(permutedims, x, (3, 1, 2))[2](dy)[2] == rrule(PermutedDimsArray, x, (3, 1, 2))[2](dy)[2] end @testset "repeat" begin + # forward + test_frule(repeat, rand(4), 2) + test_frule(repeat, rand(2, 3); fkwargs = (inner=(1,2), outer=(1,3))) + # reverse test_rrule(repeat, rand(4, )) test_rrule(repeat, rand(4, 5)) test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),)) @@ -92,6 +156,11 @@ end end @testset "hcat" begin + # forward + test_frule(hcat, randn(3, 2), randn(3)) + test_frule(hcat, randn(), randn(1,3)) + + # reverse test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3)) test_rrule(hcat, rand(), rand(1,2), rand(1,2,1)) test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2)) @@ -102,9 +171,11 @@ end @testset "reduce hcat" begin mats = [randn(3, 2), randn(3, 1), randn(3, 3)] + test_frule(reduce, hcat, mats) test_rrule(reduce, hcat, mats) vecs = [rand(3) for _ in 1:4] + test_frule(reduce, hcat, vecs) test_rrule(reduce, hcat, vecs) mix = AbstractVecOrMat[rand(4,2), rand(4)] # this is weird, but does hit the fast path @@ -121,20 +192,29 @@ end end @testset "vcat" begin + + # forward + test_frule(vcat, randn(), randn(3), rand()) + test_frule(vcat, randn(3, 1), randn(3)) + + # reverse test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4)) test_rrule(vcat, rand(), rand()) test_rrule(vcat, rand(), rand(3), rand(3,1,1)) test_rrule(vcat, rand(3,1,2), rand(4,1,2)) + # mix types test_rrule(vcat, rand(2, 2), rand(2, 2)') end @testset "reduce vcat" begin mats = [randn(2, 4), randn(1, 4), randn(3, 4)] + test_frule(reduce, vcat, mats) 
test_rrule(reduce, vcat, mats) vecs = [rand(2), rand(3), rand(4)] + test_frule(reduce, vcat, vecs) test_rrule(reduce, vcat, vecs) mix = AbstractVecOrMat[rand(4,1), rand(4)] @@ -142,6 +222,11 @@ end end @testset "cat" begin + # forward + test_frule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,)) + test_frule(cat, rand(), rand(2,3); fkwargs=(dims=(1,2),)) + + # reverse test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,)) test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),)) test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],)) @@ -151,6 +236,10 @@ end end @testset "hvcat" begin + # forward + test_frule(hvcat, 2, rand(6)...) + + # reverse test_rrule(hvcat, 2, rand(ComplexF64, 6)...) test_rrule(hvcat, (2, 1), rand(), rand(1,1), rand(2,2)) test_rrule(hvcat, 1, rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3)) @@ -170,11 +259,14 @@ end test_frule(reverse, rand(5)) test_frule(reverse, rand(5), 2, 4) test_frule(reverse, rand(5), fkwargs=(dims=1,)) - test_frule(reverse, rand(3,4), fkwargs=(dims=2,)) test_frule(reverse, rand(3,4)) test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) + test_frule(reverse!, rand(5)) + test_frule(reverse!, rand(5), 2, 4) + test_frule(reverse!, rand(3,4), fkwargs=(dims=2,)) + # Reverse test_rrule(reverse, rand(5)) test_rrule(reverse, rand(5), 2, 4) @@ -198,6 +290,9 @@ end test_frule(circshift, rand(10), (1,)) test_frule(circshift, rand(3,4), (-7,2)) + test_frule(circshift!, rand(10), rand(10), 1) + test_frule(circshift!, rand(3,4), rand(3,4), (-7,2)) + # Reverse test_rrule(circshift, rand(10), 1) test_rrule(circshift, rand(10) .+ im, -2) @@ -206,9 +301,13 @@ end end @testset "fill" begin + # Forward test_frule(fill, 12.3, 4) test_frule(fill, 5.0, (6, 7)) + test_frule(fill!, rand(2, 3), rand()) + + # Reverse test_rrule(fill, 44.4, 4) test_rrule(fill, 55 + 0.5im, 5) test_rrule(fill, 3.3, (3, 3, 3)) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index b25b16acf..324311356 100644 --- 
a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -11,24 +11,30 @@ ⋆() = only(⋆(())) # scalar @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5)) + test_frule(*, ⋆(), ⋆(dims)) + test_frule(*, ⋆(dims), ⋆()) + test_rrule(*, ⋆(), ⋆(dims)) test_rrule(*, ⋆(dims), ⋆()) end @testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5) @testset "Array" begin + test_frule(*, n ⋆ m, ⋆(m)) test_rrule(*, n ⋆ m, ⋆(m)) end end @testset "AbstractVector-AbstractMatrix n=$n, m=$m" for n in (2, 3), m in (4, 5) @testset "Array" begin + test_frule(*, ⋆(n), 1 ⋆ m) test_rrule(*, ⋆(n), 1 ⋆ m) end end @testset "dense-matrix n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) @testset "Array" begin + test_frule(*, (n⋆m), (m⋆p)) test_rrule(*, (n⋆m), (m⋆p)) end @@ -53,6 +59,11 @@ end @testset "Diagonal" begin + # fwd + test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) + test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) + + # rev test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) @@ -71,6 +82,9 @@ @testset "muladd: $T" for T in (Float64, ComplexF64) @testset "add $(typeof(z))" for z in [rand(T), rand(T, 3), rand(T, 3, 3), false] + @testset "forward mode" begin + test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z) + end @testset "matrix * matrix" begin A = rand(T, 3, 3) B = rand(T, 3, 3) @@ -166,6 +180,10 @@ @testset "/ and \\ Scalar-AbstractArray" begin A = round.(10 .* randn(3, 4, 5), digits=1) + # fwd + test_frule(/, A, 7.2) + test_frule(\, 7.2, A) + # rev test_rrule(/, A, 7.2) test_rrule(\, 7.2, A) @@ -177,12 +195,17 @@ @testset "negation" begin A = randn(4, 4) Ā = randn(4, 4) - + # fwd + test_frule(-, A) + # rev test_rrule(-, A) test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā)) end @testset "addition" begin + # fwd + test_frule(+, randn(2), randn(2), randn(2)) + # rev test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) 
test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 54939858a..b9219825e 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,6 +1,15 @@ @testset "getindex" begin @testset "getindex(::Matrix{<:Number}, ...)" begin x = [1.0 2.0 3.0; 10.0 20.0 30.0] + + @testset "forward mode" begin + test_frule(getindex, x, 2) + test_frule(getindex, x, 2, 1) + test_frule(getindex, x, CartesianIndex(2, 3)) + + test_frule(getindex, x, 2:3) + test_frule(getindex, x, (:), 2:3) + end @testset "single element" begin test_rrule(getindex, x, 2) @@ -48,3 +57,14 @@ end end end + +@testset "view" begin + test_frule(view, rand(3, 4), :, 1) + test_frule(view, rand(3, 4), 2, [1, 1, 2]) + test_frule(view, rand(3, 4), 3, 4) +end + +@testset "setindex!" begin + test_frule(setindex!, rand(3, 4), rand(), 1, 2) + test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) +end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index c9a740206..f07fc9b32 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -43,6 +43,11 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig() end end + @testset "sum!(y, x)" begin + test_frule(sum!, rand(3), rand(3, 5)) + test_frule(sum!, rand(ComplexF64, 1, 4), rand(3, 4)) + end + @testset "sum abs2" begin sizes = (3, 4, 7) @testset "dims = $dims" for dims in (:, 1) diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index f76109586..052045d1e 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -1,11 +1,19 @@ @testset "sort.jl" begin @testset "sort" begin a = rand(10) + # fwd + test_frule(sort, a) + test_frule(sort, a; fkwargs=(;rev=true)) + # rev test_rrule(sort, a) test_rrule(sort, a; fkwargs=(;rev=true)) end @testset "partialsort" begin a = rand(10) + # fwd + test_frule(partialsort, a, 3:5) + test_frule(partialsort, a, 4, fkwargs=(;rev=true)) + # rev 
test_rrule(partialsort, a, 4) test_rrule(partialsort, a, 3:5) test_rrule(partialsort, a, 1:2:6) @@ -14,6 +22,8 @@ end @testset "sortslices" begin + test_frule(sortslices, rand(3,4); fkwargs=(; dims=2)) + test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 601000d7e..d8eb50eb0 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -35,6 +35,22 @@ end end + @testset "mul!" begin + test_frule(mul!, rand(4), rand(4, 5), rand(5)) + test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) + test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) + + # Rule with α,β::Bool is only visually more complicated: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) + + # Rule with nontrivial α, β allocates A*B: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) + test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) + end + @testset "cross" begin test_frule(cross, randn(3), randn(3)) test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))