@@ -81,16 +81,48 @@ For the `rrule` of `y = x[inds...]`, this function is roughly
8181`setindex(zero(x), dy, inds...)`, returning the array `dx`.
8282Differentiable. Includes `ProjectTo(x)(dx)`.
8383"""
84- function ∇getindex (x:: AbstractArray , dy, inds... )
84+ function ∇getindex (x:: AbstractArray{T,N} , dy, inds... ) where {T,N}
8585 # `to_indices` removes any logical indexing, colons, CartesianIndex etc,
8686 # leaving just Int / AbstractVector of Int
8787 plain_inds = Base. to_indices (x, inds)
88- dx = _setindex_zero (x, dy, plain_inds... )
89- ∇getindex! (dx, dy, plain_inds... )
88+ dx = if plain_inds isa NTuple{N, Int} && T<: Number
89+ # scalar indexing
90+ OneElement (dy, plain_inds, axes (x))
91+ else # some from slicing (potentially noncontigous)
92+ dx = _setindex_zero (x, dy, plain_inds... )
93+ ∇getindex! (dx, dy, plain_inds... )
94+ end
9095 return ProjectTo (x)(dx) # since we have x, may as well do this inside, not in rules
9196end
9297∇getindex (x:: AbstractArray , z:: AbstractZero , inds... ) = z
9398
99+ """
100+ OneElement(val, ind, axes) <: AbstractArray
101+
102+ Extremely simple `struct` used for the gradient of scalar `getindex`.
103+ """
104+ struct OneElement{T,N,I,A} <: AbstractArray{T,N}
105+ val:: T
106+ ind:: I
107+ axes:: A
108+ OneElement (val:: T , ind:: I , axes:: A ) where {T<: Number , I<: NTuple{N,Int} , A<: NTuple{N,AbstractUnitRange} } where {N} = new {T,N,I,A} (val, ind, axes)
109+ end
110+ Base. size (A:: OneElement ) = map (length, A. axes)
111+ Base. axes (A:: OneElement ) = A. axes
112+ Base. getindex (A:: OneElement{T,N} , i:: Vararg{Int,N} ) where {T,N} = ifelse (i== A. ind, A. val, zero (T))
113+
114+ function ChainRulesCore. add!! (xs:: AbstractArray{<:Any,N} , oe:: OneElement{<:Any,N} ) where {N}
115+ if ! ChainRulesCore. is_inplaceable_destination (xs)
116+ xs = collect (xs)
117+ end
118+ xs[oe. ind... ] += oe. val
119+ return xs
120+ end
121+
122+ Base.:(+ )(xs:: AbstractArray , oe:: OneElement ) = add!! (copy (xs), oe)
123+ Base.:(+ )(oe:: OneElement , xs:: AbstractArray ) = + (xs, oe)
124+ Base.:(+ )(oe1:: OneElement , oe2:: OneElement ) = + (collect (oe1), oe2)
125+
94126"""
95127 _setindex_zero(x, dy, inds...)
96128
0 commit comments