Skip to content

Commit b1f8d93

Browse files
committed
Vector indexing for OneElement
1 parent 4f8a966 commit b1f8d93

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

src/oneelement.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,27 @@ OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)
4242

4343
Base.size(A::OneElement) = map(length, A.axes)
4444
Base.axes(A::OneElement) = A.axes
45+
Base.getindex(A::OneElement{T,0}) where {T} = getindex_value(A)
4546
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
4647
@boundscheck checkbounds(A, kj...)
4748
ifelse(kj == A.ind, A.val, zero(T))
4849
end
50+
const VectorIndsWithColon = Union{AbstractRange{Int}, Colon, Int}
51+
const VectorInds = Union{AbstractRange{Int}, Int}
52+
# retain the values from Ainds corresponding to the vector indices in inds
53+
_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds))
54+
_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...)
55+
_index_shape(::Tuple{}, ::Tuple{}) = ()
56+
@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N}
57+
@boundscheck checkbounds(A, inds...)
58+
shape = _index_shape(inds, inds)
59+
nzind = _index_shape(A.ind, inds) .- first.(shape) .+ firstindex.(shape)
60+
containsval = all(in.(A.ind, inds))
61+
OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1))
62+
end
63+
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N}
64+
getindex(A, Base.to_indices(A, inds)...)
65+
end
4966

5067
"""
5168
nzind(A::OneElement{T,N}) -> CartesianIndex{N}

test/runtests.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,10 +2148,12 @@ end
21482148
@test FillArrays.nzind(A) == CartesianIndex()
21492149
@test A == Fill(2, ())
21502150
@test A[] === 2
2151+
@test A[1] === A[1,1] === 2
21512152

21522153
e₁ = OneElement(2, 5)
21532154
@test e₁ == [0,1,0,0,0]
21542155
@test FillArrays.nzind(e₁) == CartesianIndex(2)
2156+
@test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1
21552157
@test_throws BoundsError e₁[6]
21562158

21572159
f₁ = AbstractArray{Float64}(e₁)
@@ -2193,6 +2195,82 @@ end
21932195
@test A[1,1] === A[1,2] === A[2,1] === zero(S)
21942196
end
21952197

2198+
@testset "Vector indexing" begin
2199+
@testset "1D" begin
2200+
A = OneElement(2, 2, 4)
2201+
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
2202+
@test @inferred(A[3:4]) isa OneElement{Int,1}
2203+
@test @inferred(A[3:4]) == Zeros(2)
2204+
@test @inferred(A[1:2]) === OneElement(2, 2, 2)
2205+
@test @inferred(A[2:3]) === OneElement(2, 1, 2)
2206+
@test @inferred(A[Base.IdentityUnitRange(2:3)]) isa OneElement{Int,1}
2207+
@test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),))
2208+
@test A[:,:] == reshape(A, size(A)..., 1)
2209+
2210+
B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),))
2211+
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
2212+
@test @inferred(A[3:4]) isa OneElement{Int,1}
2213+
@test @inferred(A[3:4]) == Zeros(2)
2214+
@test @inferred(A[2:3]) === OneElement(2, 1, 2)
2215+
2216+
C = OneElement(2, (2,), (Base.OneTo(big(4)),))
2217+
@test @inferred(C[1:4]) === OneElement(2, 2, 4)
2218+
2219+
D = OneElement(2, (2,), (InfiniteArrays.OneToInf(),))
2220+
D2 = D[:]
2221+
@test axes(D2) == axes(D)
2222+
@test D2[2] == D[2]
2223+
D3 = D[axes(D)...]
2224+
@test axes(D3) == axes(D)
2225+
@test D3[2] == D[2]
2226+
end
2227+
@testset "2D" begin
2228+
A = OneElement(2, (2,3), (4,5))
2229+
@test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A
2230+
@test @inferred(A[:,1]) isa OneElement{Int,1}
2231+
@test @inferred(A[:,1]) == Zeros(4)
2232+
@test @inferred(A[1,:]) isa OneElement{Int,1}
2233+
@test @inferred(A[1,:]) == Zeros(5)
2234+
@test @inferred(A[:,3]) === OneElement(2, 2, 4)
2235+
@test @inferred(A[2,:]) === OneElement(2, 3, 5)
2236+
@test @inferred(A[1:1,:]) isa OneElement{Int,2}
2237+
@test @inferred(A[1:1,:]) == Zeros(1,5)
2238+
@test @inferred(A[4:4,:]) isa OneElement{Int,2}
2239+
@test @inferred(A[4:4,:]) == Zeros(1,5)
2240+
@test @inferred(A[2:2,:]) === OneElement(2, (1,3), (1,5))
2241+
@test @inferred(A[1:4,:]) === OneElement(2, (2,3), (4,5))
2242+
@test @inferred(A[:,3:3]) === OneElement(2, (2,1), (4,1))
2243+
@test @inferred(A[:,1:5]) === OneElement(2, (2,3), (4,5))
2244+
@test @inferred(A[1:4,1:4]) === OneElement(2, (2,3), (4,4))
2245+
@test @inferred(A[2:4,2:4]) === OneElement(2, (1,2), (3,3))
2246+
@test @inferred(A[2:4,3:4]) === OneElement(2, (1,1), (3,2))
2247+
@test @inferred(A[4:4,5:5]) isa OneElement{Int,2}
2248+
@test @inferred(A[4:4,5:5]) == Zeros(1,1)
2249+
@test @inferred(A[Base.IdentityUnitRange(2:4), :]) isa OneElement{Int,2}
2250+
@test axes(A[Base.IdentityUnitRange(2:4), :]) == (Base.IdentityUnitRange(2:4), axes(A,2))
2251+
@test @inferred(A[:,:,:]) == reshape(A, size(A)...,1)
2252+
2253+
B = OneElement(2, (2,3), (Base.IdentityUnitRange(2:4),Base.IdentityUnitRange(2:5)))
2254+
@test @inferred(B[:,:]) === @inferred(B[axes(B)...]) === B
2255+
@test @inferred(B[:,3]) === OneElement(2, (2,), (Base.IdentityUnitRange(2:4),))
2256+
@test @inferred(B[3:4, 4:5]) isa OneElement{Int,2}
2257+
@test @inferred(B[3:4, 4:5]) == Zeros(2,2)
2258+
b = @inferred(B[Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(4:5)])
2259+
@test b == Zeros(axes(b))
2260+
2261+
C = OneElement(2, (2,3), (Base.OneTo(big(4)), Base.OneTo(big(5))))
2262+
@test @inferred(C[1:4, 1:5]) === OneElement(2, (2,3), Int.(size(C)))
2263+
2264+
D = OneElement(2, (2,3), (InfiniteArrays.OneToInf(), InfiniteArrays.OneToInf()))
2265+
D2 = @inferred D[:,:]
2266+
@test axes(D2) == axes(D)
2267+
@test D2[2,3] == D[2,3]
2268+
D3 = @inferred D[axes(D)...]
2269+
@test axes(D3) == axes(D)
2270+
@test D3[2,3] == D[2,3]
2271+
end
2272+
end
2273+
21962274
@testset "adjoint/transpose" begin
21972275
A = OneElement(3im, (2,4), (4,6))
21982276
@test A' === OneElement(-3im, (4,2), (6,4))

0 commit comments

Comments
 (0)