Skip to content

Commit fbe8271

Browse files
authored
Merge pull request #1328 from FluxML/bc/rm-getindex-adjoint
Excise getindex adjoint
2 parents f755127 + 08e0cd8 commit fbe8271

File tree

4 files changed

+10
-28
lines changed

4 files changed

+10
-28
lines changed

src/deprecated.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,9 @@ macro nograd(ex)
6565
end
6666
return blk
6767
end
68+
69+
# Internal function used by some downstream packages.
70+
# Removing this completely would require some tricky registry changes,
71+
# but leaving it as a vestigial function is much easier.
72+
# See https://github.com/FluxML/Zygote.jl/pull/1328 for more context.
73+
function ∇getindex end

src/lib/array.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,6 @@ end
4141
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
4242
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)
4343

44-
@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)
45-
46-
@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)
47-
48-
∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
49-
if inds isa NTuple{N,Int} && T <: Number
50-
dx = OneElement(dy, inds, axes(x))
51-
elseif inds isa NTuple{<:Any, Integer}
52-
dx = _zero(x, typeof(dy))
53-
dx[inds...] = dy
54-
else
55-
dx = _zero(x, eltype(dy))
56-
dxv = view(dx, inds...)
57-
dxv .= accum.(dxv, _droplike(dy, dxv))
58-
end
59-
return (_project(x, dx), map(_->nothing, inds)...)
60-
end
61-
6244
"""
6345
OneElement(val, ind, axes) <: AbstractArray
6446

test/gradcheck.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ end
174174

175175
# Ensure that nothings work with numeric types.
176176
_, back = Zygote.pullback(getindex, randn(4), [1])
177-
@test back([nothing]) == (zeros(4), nothing)
177+
@test back([nothing]) === nothing
178178

179179
# Ensure that nothings work with non-numeric types.
180180
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
181-
@test back([nothing]) == (nothing, nothing)
181+
@test back([nothing]) === nothing
182182
end
183183

184184
@testset "view" begin

test/utils.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,8 @@ using ForwardDiff
22
using Zygote: hessian_dual, hessian_reverse
33

44
@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]
5-
6-
if hess == hessian_dual
7-
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
8-
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
9-
else
10-
@test_broken hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0] # can't differentiate ∇getindex
11-
@test_broken hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0]
12-
end
5+
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
6+
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
137
@test hess(x -> sum(x.^3), [1 2; 3 4]) Diagonal([6, 18, 12, 24])
148
@test hess(sin, pi/2) -1
159

0 commit comments

Comments
 (0)