Skip to content
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ version = "2.38.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -18,26 +16,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Adapt = "3"
ArrayInterface = "7"
ChainRulesCore = "0.10.7, 1"
DocStringExtensions = "0.8, 0.9"
FillArrays = "0.11, 0.12, 0.13"
GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
StaticArraysCore = "1.1"
SymbolicIndexingInterface = "0.1, 0.2"
Tables = "1"
ZygoteRules = "0.2"
Zygote = "< 0.6.56"
julia = "1.6"

[extensions]
RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand All @@ -60,3 +56,4 @@ test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "Ord

[weakdeps]
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
113 changes: 113 additions & 0 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module RecursiveArrayToolsZygoteExt

using RecursiveArrayTools

if isdefined(Base, :get_extension)
using Zygote
using Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint
else
using ..Zygote
using ..Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint
end

# Define a new species of projection operator for this type:
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()

function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
xs::AbstractVectorOfArray)
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(VectorOfArray(Δ′), nothing)
end
VA[i], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{BitArray, AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(VectorOfArray(Δ′), nothing)
end
VA[i], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
function AbstractVectorOfArray_getindex_adjoint(Δ)
iter = 0
Δ′ = [(j ∈ i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(VectorOfArray(Δ′), nothing)
end
VA[i], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{Int, AbstractArray{Int}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(VectorOfArray(Δ′), nothing)
end
VA[i], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
function AbstractVectorOfArray_getindex_adjoint(Δ)
(VectorOfArray(Δ), nothing)
end
VA[i], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
j::Union{Int, AbstractArray{Int}, CartesianIndex,
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
(Δ′, nothing, map(_ -> nothing, j)...)
end
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
end

@adjoint function ArrayPartition(x::S,
::Type{Val{copy_x}} = Val{false}) where {
S <:
Tuple,
copy_x
}
function ArrayPartition_adjoint(_y)
y = Array(_y)
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)),
nothing
end

ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
end

@adjoint function VectorOfArray(u)
VectorOfArray(u),
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
for i in 1:size(y)[end]]),)
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
t), nothing)
end

@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
function literal_ArrayPartition_x_adjoint(d)
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
end
A.x, literal_ArrayPartition_x_adjoint
end

end
12 changes: 2 additions & 10 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface

import ChainRulesCore
import ChainRulesCore: NoTangent
import ZygoteRules, Adapt

using FillArrays
import Adapt

import Tables, IteratorInterfaceExtensions

Expand All @@ -24,23 +20,19 @@ include("utils.jl")
include("vector_of_array.jl")
include("tabletraits.jl")
include("array_partition.jl")
include("zygote.jl")

function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
invoke(show, Tuple{typeof(io), Any}, io, x)
end

import GPUArraysCore
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
xs::AbstractVectorOfArray)
T(xs), ȳ -> (NoTangent(), ȳ)
end

import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ Base.ones(A::ArrayPartition, dims::NTuple{N, Int}) where {N} = ones(A)

# mutable iff all components of ArrayPartition are mutable
@generated function ArrayInterface.ismutable(::Type{<:ArrayPartition{T, S}}) where {T, S
}
}
res = all(ArrayInterface.ismutable, S.parameters)
return :($res)
end
Expand Down
160 changes: 0 additions & 160 deletions src/zygote.jl

This file was deleted.

Loading