diff --git a/src/axes.jl b/src/axes.jl index bd4fd85ba..03eacb740 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -88,7 +88,7 @@ Returns the axis associated with each dimension of `A` or dimension `dim`. exception of a handful of types replace `Base.OneTo{Int}` with `ArrayInterface.SOneTo`. For example, the axis along the first dimension of `Transpose{T,<:AbstractVector{T}}` and `Adjoint{T,<:AbstractVector{T}}` can be represented by `SOneTo(1)`. Similarly, -`Base.ReinterpretArray`'s first axis may be statically sized. +`Base.ReinterpretArray`'s first axis may be statically sized. """ @inline axes(A) = Base.axes(A) axes(A::ReshapedArray) = Base.axes(A) @@ -112,7 +112,22 @@ end return getfield(axes(A), Int(dim)) end end -axes(A::SubArray, dim) = Base.axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1) + +@inline function axes(A::SubArray, dim::Integer) + if dim > ndims(A) + return OneTo(1) + else + return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1) + end +end +@inline function axes(A::SubArray, ::StaticInt{dim}) where {dim} + if dim > ndims(A) + return SOneTo{1}() + else + return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1) + end +end + if isdefined(Base, :ReshapedReinterpretArray) function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}} if sizeof(S) > sizeof(T) diff --git a/test/axes.jl b/test/axes.jl index c357d19da..2ea85a94e 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -44,6 +44,29 @@ m = Array{Float64}(undef, 4, 3) @test_throws DimensionMismatch ArrayInterface.LazyAxis{0}(A) end +@testset "`axes(A, dim)`` with `dim > ndims(A)` (#224)" begin + m = 2 + n = 3 + B = Array{Float64, 2}(undef, m, n) + b = view(B, :, 1) + + @test @inferred(ArrayInterface.axes(B, 1)) == 1:m + @test @inferred(ArrayInterface.axes(B, 2)) == 1:n + @test @inferred(ArrayInterface.axes(B, 3)) == 1:1 + + @test @inferred(ArrayInterface.axes(B, static(1))) == 1:m + @test @inferred(ArrayInterface.axes(B, static(2))) == 1:n + @test @inferred(ArrayInterface.axes(B, static(3))) == 1:1 + + @test @inferred(ArrayInterface.axes(b, 1)) == 1:m + @test @inferred(ArrayInterface.axes(b, 2)) == 1:1 + @test @inferred(ArrayInterface.axes(b, 3)) == 1:1 + + @test @inferred(ArrayInterface.axes(b, static(1))) == 1:m + @test @inferred(ArrayInterface.axes(b, static(2))) == 1:1 + @test @inferred(ArrayInterface.axes(b, static(3))) == 1:1 +end + if isdefined(Base, :ReshapedReinterpretArray) a = rand(3, 5) ua = reinterpret(reshape, UInt64, a)