Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ jobs:
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Core}
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface}
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core1}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core2}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
22 changes: 14 additions & 8 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,29 @@ end
@adjoint function Base.Array(VA::AbstractVectorOfArray)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
VA = recursivecopy(VA)
copyto!(VA, y)
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
end
return view(A, I...), adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)
end
view(A, I...), adj
view(A, I...), view_adjoint
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function recursivecopy(a::AbstractVectorOfArray)
b = copy(a)
b.u = recursivecopy.(a.u)
b.u .= recursivecopy.(a.u)
return b
end

Expand Down