From 08f39929c88c1dc570712a610c0109af83fcc379 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Mar 2023 11:34:34 -0400 Subject: [PATCH 01/13] Remove FillArrays dependency to reduce import times It was added in https://github.com/SciML/RecursiveArrayTools.jl/commit/b3ed973b0167f0e60f4b7c4e801af850ce021b17 but in https://github.com/PumasAI/DataInterpolations.jl/issues/129 it has been identified as one of the main import time contributions. It's only necessary for the Zygote rules, and not even ChainRules but specifically `Zygote.@adjoint`, so they are simply moved to a Zygote extension and use FillArrays from Zygote to achieve the same goal. --- Project.toml | 3 +- ext/RecursiveArrayToolsZygoteExt.jl | 105 ++++++++++++++++++++++++++++ src/RecursiveArrayTools.jl | 3 +- src/zygote.jl | 94 +------------------------ 4 files changed, 109 insertions(+), 96 deletions(-) create mode 100644 ext/RecursiveArrayToolsZygoteExt.jl diff --git a/Project.toml b/Project.toml index c23a76ec..d8def414 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -38,6 +37,7 @@ julia = "1.6" [extensions] RecursiveArrayToolsTrackerExt = "Tracker" +RecursiveArrayToolsZygoteExt = "Zygote" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" @@ -60,3 +60,4 @@ test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "Ord [weakdeps] Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" \ No newline at end of file diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl new file mode 100644 index 00000000..ad5b4113 --- /dev/null +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -0,0 +1,105 @@ +module RecursiveArrayToolsZygoteExt + +import RecursiveArrayTools + +if isdefined(Base, :get_extension) + using Zygote + using Zygote: ZygoteRules, FullArrays +else + using ..Zygote + using ..Zygote: ZygoteRules, FullArrays +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) + function AbstractVectorOfArray_getindex_adjoint(Δ) + Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (VectorOfArray(Δ′), nothing) + end + VA[i], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, + i::Union{BitArray, AbstractArray{Bool}}) + function AbstractVectorOfArray_getindex_adjoint(Δ) + Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (VectorOfArray(Δ′), nothing) + end + VA[i], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) + function AbstractVectorOfArray_getindex_adjoint(Δ) + iter = 0 + Δ′ = [(j ∈ i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (VectorOfArray(Δ′), nothing) + end + VA[i], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, + i::Union{Int, AbstractArray{Int}}) + function AbstractVectorOfArray_getindex_adjoint(Δ) + Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (VectorOfArray(Δ′), nothing) + end + VA[i], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon) + function AbstractVectorOfArray_getindex_adjoint(Δ) + (VectorOfArray(Δ), nothing) + end + VA[i], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, + j::Union{Int, AbstractArray{Int}, CartesianIndex, + Colon, BitArray, AbstractArray{Bool}}...) + function AbstractVectorOfArray_getindex_adjoint(Δ) + Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) + Δ′[i, j...] = Δ + (Δ′, nothing, map(_ -> nothing, j)...) + end + VA[i, j...], AbstractVectorOfArray_getindex_adjoint +end + +ZygoteRules.@adjoint function ArrayPartition(x::S, + ::Type{Val{copy_x}} = Val{false}) where { + S <: + Tuple, + copy_x + } + function ArrayPartition_adjoint(_y) + y = Array(_y) + starts = vcat(0, cumsum(reduce(vcat, length.(x)))) + ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)), + nothing + end + + ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint +end + +ZygoteRules.@adjoint function VectorOfArray(u) + VectorOfArray(u), + y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] + for i in 1:size(y)[end]]),) +end + +ZygoteRules.@adjoint function DiffEqArray(u, t) + DiffEqArray(u, t), + y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], + t), nothing) +end + +ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x}) + function literal_ArrayPartition_x_adjoint(d) + (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) + end + A.x, literal_ArrayPartition_x_adjoint +end + +end \ No newline at end of file diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index f4fc7cc8..f67d8b60 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -13,8 +13,6 @@ import ChainRulesCore import ChainRulesCore: NoTangent import ZygoteRules, Adapt -using FillArrays - import Tables, IteratorInterfaceExtensions abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end @@ -41,6 +39,7 @@ import Requires @static if !isdefined(Base, :get_extension) function __init__() Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end + Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end end end diff --git a/src/zygote.jl b/src/zygote.jl index 15672da8..e4d9f164 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -65,96 +65,4 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr #(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` -# definition first, and finds its own before finding those. - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{BitArray, AbstractArray{Bool}}) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) - function AbstractVectorOfArray_getindex_adjoint(Δ) - iter = 0 - Δ′ = [(j ∈ i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{Int, AbstractArray{Int}}) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (VectorOfArray(Δ′), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon) - function AbstractVectorOfArray_getindex_adjoint(Δ) - (VectorOfArray(Δ), nothing) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, - j::Union{Int, AbstractArray{Int}, CartesianIndex, - Colon, BitArray, AbstractArray{Bool}}...) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) - Δ′[i, j...] = Δ - (Δ′, nothing, map(_ -> nothing, j)...) - end - VA[i, j...], AbstractVectorOfArray_getindex_adjoint -end - -ZygoteRules.@adjoint function ArrayPartition(x::S, - ::Type{Val{copy_x}} = Val{false}) where { - S <: - Tuple, - copy_x - } - function ArrayPartition_adjoint(_y) - y = Array(_y) - starts = vcat(0, cumsum(reduce(vcat, length.(x)))) - ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)), - nothing - end - - ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint -end - -ZygoteRules.@adjoint function VectorOfArray(u) - VectorOfArray(u), - y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] - for i in 1:size(y)[end]]),) -end - -ZygoteRules.@adjoint function DiffEqArray(u, t) - DiffEqArray(u, t), - y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], - t), nothing) -end - -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x}) - function literal_ArrayPartition_x_adjoint(d) - (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) - end - A.x, literal_ArrayPartition_x_adjoint -end +# definition first, and finds its own before finding those. \ No newline at end of file From 8aa1e1e4dd227a2286be541bdaf7200ee2e798cf Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Mar 2023 11:37:04 -0400 Subject: [PATCH 02/13] remove FillArrays compat --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index d8def414..0f07bcc5 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,6 @@ Adapt = "3" ArrayInterface = "7" ChainRulesCore = "0.10.7, 1" DocStringExtensions = "0.8, 0.9" -FillArrays = "0.11, 0.12, 0.13" GPUArraysCore = "0.1" IteratorInterfaceExtensions = "1" RecipesBase = "0.7, 0.8, 1.0" From 2d9cb7836a6a5c9449bc944a3b53c77de2412aa2 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Mar 2023 12:41:31 -0400 Subject: [PATCH 03/13] no ZygoteRules --- ext/RecursiveArrayToolsZygoteExt.jl | 44 ++++++++++++++--------------- src/array_partition.jl | 2 +- src/zygote.jl | 2 +- test/partitions_test.jl | 26 ++++++++--------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index ad5b4113..34d0992b 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -4,13 +4,13 @@ import RecursiveArrayTools if isdefined(Base, :get_extension) using Zygote - using Zygote: ZygoteRules, FullArrays + using Zygote: FullArrays, literal_getproperty, @adjoint else using ..Zygote - using ..Zygote: ZygoteRules, FullArrays + using ..Zygote: FullArrays, literal_getproperty, @adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) +@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -19,8 +19,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) VA[i], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{BitArray, AbstractArray{Bool}}) +@adjoint function getindex(VA::AbstractVectorOfArray, + i::Union{BitArray, AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -29,7 +29,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, VA[i], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) +@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) function AbstractVectorOfArray_getindex_adjoint(Δ) iter = 0 Δ′ = [(j ∈ i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x))) @@ -39,8 +39,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArr VA[i], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, - i::Union{Int, AbstractArray{Int}}) +@adjoint function getindex(VA::AbstractVectorOfArray, + i::Union{Int, AbstractArray{Int}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] @@ -49,16 +49,16 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, VA[i], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon) +@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon) function AbstractVectorOfArray_getindex_adjoint(Δ) (VectorOfArray(Δ), nothing) end VA[i], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, - j::Union{Int, AbstractArray{Int}, CartesianIndex, - Colon, BitArray, AbstractArray{Bool}}...) +@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, + j::Union{Int, AbstractArray{Int}, CartesianIndex, + Colon, BitArray, AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) Δ′[i, j...] = Δ @@ -67,12 +67,12 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, VA[i, j...], AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function ArrayPartition(x::S, - ::Type{Val{copy_x}} = Val{false}) where { - S <: - Tuple, - copy_x - } +@adjoint function ArrayPartition(x::S, + ::Type{Val{copy_x}} = Val{false}) where { + S <: + Tuple, + copy_x + } function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0, cumsum(reduce(vcat, length.(x)))) @@ -83,23 +83,23 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint end -ZygoteRules.@adjoint function VectorOfArray(u) +@adjoint function VectorOfArray(u) VectorOfArray(u), y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]]),) end -ZygoteRules.@adjoint function DiffEqArray(u, t) +@adjoint function DiffEqArray(u, t) DiffEqArray(u, t), y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], t), nothing) end -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x}) +@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x}) function literal_ArrayPartition_x_adjoint(d) (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) end A.x, literal_ArrayPartition_x_adjoint end -end \ No newline at end of file +end diff --git a/src/array_partition.jl b/src/array_partition.jl index 3ccadc49..a72b9c63 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -115,7 +115,7 @@ Base.ones(A::ArrayPartition, dims::NTuple{N, Int}) where {N} = ones(A) # mutable iff all components of ArrayPartition are mutable @generated function ArrayInterface.ismutable(::Type{<:ArrayPartition{T, S}}) where {T, S - } + } res = all(ArrayInterface.ismutable, S.parameters) return :($res) end diff --git a/src/zygote.jl b/src/zygote.jl index e4d9f164..7f2778a3 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -65,4 +65,4 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr #(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx # These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` -# definition first, and finds its own before finding those. \ No newline at end of file +# definition first, and finds its own before finding those. diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 7caba936..1aa13f3f 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -191,20 +191,20 @@ up = 2 .* ap .+ 1 @test typeof(ap) == typeof(up) @testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1, + 2, + false), + ([ + 1, + ], + 2, + false), + ([ + 1, + ], + [ 2, - false), - ([ - 1, - ], - 2, - false), - ([ - 1, - ], - [ - 2, - ], - true)) + ], + true)) @test ArrayInterface.ismutable(ArrayPartition(a, b)) == r end From 08efc723194d04176852189030710eb3ef79fb8a Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Mar 2023 12:53:58 -0400 Subject: [PATCH 04/13] FillArrays --- ext/RecursiveArrayToolsZygoteExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 34d0992b..3832bd21 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -4,10 +4,10 @@ import RecursiveArrayTools if isdefined(Base, :get_extension) using Zygote - using Zygote: FullArrays, literal_getproperty, @adjoint + using Zygote: FillArrays, literal_getproperty, @adjoint else using ..Zygote - using ..Zygote: FullArrays, literal_getproperty, @adjoint + using ..Zygote: FillArrays, literal_getproperty, @adjoint end @adjoint function getindex(VA::AbstractVectorOfArray, i::Int) From 056450661a524383c17654eae3a4837a076065ae Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Mar 2023 13:10:24 -0400 Subject: [PATCH 05/13] using RAT --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 3832bd21..b94327c7 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -1,6 +1,6 @@ module RecursiveArrayToolsZygoteExt -import RecursiveArrayTools +using RecursiveArrayTools if isdefined(Base, :get_extension) using Zygote From 0e7d97f666bd058f715e6c33db6d30c803b4660a Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 00:15:07 -0400 Subject: [PATCH 06/13] remove extra chainrules --- src/zygote.jl | 66 --------------------------------------------------- 1 file changed, 66 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index 7f2778a3..c06eca4f 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,68 +1,2 @@ -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, - i::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, - BitArray, AbstractArray{Bool}}) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = [(i == j ? Δ : zero(x)) for (x, j) in zip(VA.u, 1:length(VA))] - (NoTangent(), VectorOfArray(Δ′), NoTangent()) - end - VA[i], AbstractVectorOfArray_getindex_adjoint -end - -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, - indices::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, - BitArray, AbstractArray{Bool}}...) - function AbstractVectorOfArray_getindex_adjoint(Δ) - Δ′ = zero(VA) - Δ′[indices...] = Δ - (NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...) - end - VA[indices...], AbstractVectorOfArray_getindex_adjoint -end - -function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, - ::Type{Val{copy_x}} = Val{false}) where {S <: Tuple, copy_x} - function ArrayPartition_adjoint(_y) - y = Array(_y) - starts = vcat(0, cumsum(reduce(vcat, length.(x)))) - NoTangent(), - ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)), - NoTangent() - end - - ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint -end - -function ChainRulesCore.rrule(::Type{<:VectorOfArray}, u) - VectorOfArray(u), - y -> (NoTangent(), - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]]) -end - -function ChainRulesCore.rrule(::Type{<:DiffEqArray}, u, t) - DiffEqArray(u, t), - y -> (NoTangent(), - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], - NoTangent()) -end - -function ChainRulesCore.rrule(::typeof(getproperty), A::ArrayPartition, s::Symbol) - if s !== :x - error("$s is not a field of ArrayPartition") - end - function literal_ArrayPartition_x_adjoint(d) - (NoTangent(), - ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...)) - end - A.x, literal_ArrayPartition_x_adjoint -end - # Define a new species of projection operator for this type: ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() - -# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) -# Gradient from broadcasting will be another AbstractArray -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx - -# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` -# definition first, and finds its own before finding those. From c0577293a44e55cf467fc6b2e3a6ff9ffca4a981 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 00:58:18 -0400 Subject: [PATCH 07/13] Fix Zygote version --- Project.toml | 5 +---- src/RecursiveArrayTools.jl | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 0f07bcc5..df95cdfc 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "2.38.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" @@ -17,12 +16,10 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "3" ArrayInterface = "7" -ChainRulesCore = "0.10.7, 1" DocStringExtensions = "0.8, 0.9" GPUArraysCore = "0.1" IteratorInterfaceExtensions = "1" @@ -31,7 +28,7 @@ Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" -ZygoteRules = "0.2" +Zygote = "= 0.6.55" julia = "1.6" [extensions] diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index f67d8b60..6a3a4f56 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -9,9 +9,7 @@ using RecipesBase, StaticArraysCore, Statistics, ArrayInterface, LinearAlgebra using SymbolicIndexingInterface -import ChainRulesCore -import ChainRulesCore: NoTangent -import ZygoteRules, Adapt +import Adapt import Tables, IteratorInterfaceExtensions From 2d6b9cdb03ad3dd55f98c2cc425dd9f066fc8dc0 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 01:10:10 -0400 Subject: [PATCH 08/13] further simplify --- ext/RecursiveArrayToolsZygoteExt.jl | 7 +++++-- src/RecursiveArrayTools.jl | 1 - src/zygote.jl | 2 -- 3 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 src/zygote.jl diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index b94327c7..17fdcbe6 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -4,12 +4,15 @@ using RecursiveArrayTools if isdefined(Base, :get_extension) using Zygote - using Zygote: FillArrays, literal_getproperty, @adjoint + using Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint else using ..Zygote - using ..Zygote: FillArrays, literal_getproperty, @adjoint + using ..Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint end +# Define a new species of projection operator for this type: +ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() + @adjoint function getindex(VA::AbstractVectorOfArray, i::Int) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x))) diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 6a3a4f56..a2065722 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -20,7 +20,6 @@ include("utils.jl") include("vector_of_array.jl") include("tabletraits.jl") include("array_partition.jl") -include("zygote.jl") function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray}) invoke(show, Tuple{typeof(io), Any}, io, x) diff --git a/src/zygote.jl b/src/zygote.jl deleted file mode 100644 index c06eca4f..00000000 --- a/src/zygote.jl +++ /dev/null @@ -1,2 +0,0 @@ -# Define a new species of projection operator for this type: -ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() From 087fa42d8ebb839777d382f7757bf520d2788214 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 01:33:03 -0400 Subject: [PATCH 09/13] fix removal --- ext/RecursiveArrayToolsZygoteExt.jl | 5 +++++ src/RecursiveArrayTools.jl | 4 ---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 17fdcbe6..cee91758 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -13,6 +13,11 @@ end # Define a new species of projection operator for this type: ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() +function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, + xs::AbstractVectorOfArray) + T(xs), ȳ -> (NoTangent(), ȳ) +end + @adjoint function getindex(VA::AbstractVectorOfArray, i::Int) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x))) diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index a2065722..fc97b082 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -27,10 +27,6 @@ end import GPUArraysCore Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA) -function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, - xs::AbstractVectorOfArray) - T(xs), ȳ -> (NoTangent(), ȳ) -end import Requires @static if !isdefined(Base, :get_extension) From b45ef3ccc703ddd287c1aa71d97677f2c3f03e45 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 01:52:25 -0400 Subject: [PATCH 10/13] namespace --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index cee91758..0a4d9cd1 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -13,7 +13,7 @@ end # Define a new species of projection operator for this type: ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() -function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, +function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) T(xs), ȳ -> (NoTangent(), ȳ) end From 6b95b9c2db2f756896fcf77a57a48f7127b4c97c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 03:07:32 -0400 Subject: [PATCH 11/13] bump Zygote version --- Project.toml | 2 +- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index df95cdfc..1347e3df 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" -Zygote = "= 0.6.55" +Zygote = "= 0.6.57" julia = "1.6" [extensions] diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 0a4d9cd1..d8e167c0 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -15,7 +15,7 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) - T(xs), ȳ -> (NoTangent(), ȳ) + T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ) end @adjoint function getindex(VA::AbstractVectorOfArray, i::Int) From 840fcb75862c835ea63ae5e208f72a38c1706eb4 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 03:25:57 -0400 Subject: [PATCH 12/13] reduce version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1347e3df..fa698074 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" -Zygote = "= 0.6.57" +Zygote = "= 0.6.56" julia = "1.6" [extensions] From 9f10f06b62fcdd13a93ad60c757e4efec0ac746c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 04:12:13 -0400 Subject: [PATCH 13/13] add version limit --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa698074..d86a3434 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" -Zygote = "= 0.6.56" +Zygote = "< 0.6.56" julia = "1.6" [extensions]