Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr

const CanonicalInt = Union{Int,StaticInt}

@static if VERSION ≥ v"1.6.0-DEV.1581"
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false
else
_is_reshaped(::Type{ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = false
@static if isdefined(Base, :ReshapedReinterpretArray)
_is_reshaped(::Type{<:Base.ReshapedReinterpretArray}) = true
end
_is_reshaped(::Type{<:ReinterpretArray}) = false

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
Expand Down
15 changes: 11 additions & 4 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ end

@inline _axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
@inline function _axes(A::ReinterpretArray{T,N,S}, dim::Integer) where {T,N,S}
if _is_reshaped(typeof(A)) && (sizeof(S) > sizeof(T)) && dim == 1
return One():static(div(sizeof(S), sizeof(T)))
end
Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
if _is_reshaped(typeof(A)) && (sizeof(S) > sizeof(T)) && dim == 1
return One():static(div(sizeof(S), sizeof(T)))
end
Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
end
@inline _axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version

Expand All @@ -174,6 +174,13 @@ axes(A::Union{Transpose,Adjoint}) = _axes(A, parent(A))
_axes(A::Union{Transpose,Adjoint}, p::AbstractVector) = (One():One(), axes(p, One()))
_axes(A::Union{Transpose,Adjoint}, p::AbstractMatrix) = (axes(p, StaticInt(2)), axes(p, One()))
axes(A::SubArray) = Base.axes(A) # TODO implement ArrayInterface version
if isdefined(Base, :ReshapedReinterpretArray)
function axes(A::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
sizeof(S) > sizeof(T) && return (_axes(A, 1), axes(parent(A))...)
sizeof(S) < sizeof(T) && return Base.tail(axes(parent(A)))
return axes(parent(A))
end
end
axes(A::ReinterpretArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::CartesianIndices) = A.indices
Expand Down
108 changes: 48 additions & 60 deletions src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,31 +237,31 @@ stride_rank(x, i) = stride_rank(x)[i]
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
return nstatic(Val(N))
end
if VERSION ≥ v"1.6.0-DEV.1581"
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
end
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
# doesn't currently have a means of representing.
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
end
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
for n in 1:N
if t.parameters[n+1].parameters[1] === 2
if d.parameters[n+1] === True
return :(StaticInt{$n}())
else
return :(Zero())
if isdefined(Base, :ReshapedReinterpretArray)
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}}
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
end
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
# doesn't currently have a means of representing.
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}}
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
end
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
for n in 1:N
if t.parameters[n+1].parameters[1] === 2
if d.parameters[n+1] === True
return :(StaticInt{$n}())
else
return :(Zero())
end
end
end
end
:(Zero())
end
:(Zero())
end
end

function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
Expand Down Expand Up @@ -376,8 +376,8 @@ end
function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _dense_dims(S, dense_dims(A), Val(stride_rank(A)))
end
if VERSION ≥ v"1.6.0-DEV.1581"
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
if isdefined(Base, :ReshapedReinterpretArray)
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}}
ddb = dense_dims(B)
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
end
Expand Down Expand Up @@ -539,47 +539,35 @@ end
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
@static if VERSION ≥ v"1.6.0-DEV.1581"
# from `reinterpret(reshape, ...)`
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
P = strides(parent(A))
if sizeof(R) == sizeof(T)
P
elseif sizeof(R) > sizeof(T)
x = Base.tail(P)
fx = first(x)
if fx isa Int
(One(), bmap(Base.sdiv_int, Base.tail(x), fx)...)
else
(One(), bmap(÷, Base.tail(x), fx)...)
end
else
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
end
end
# plain `reinterpret(...)`
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, false}) where {R,N,T,B}
P = strides(parent(A))
if sizeof(R) == sizeof(T)
P
elseif sizeof(R) > sizeof(T)
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
else # sizeof(R) < sizeof(T)
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
@static if isdefined(Base, :ReshapedReinterpretArray)
# from `reinterpret(reshape, ...)`
@inline function strides(A::Base.ReshapedReinterpretArray{R, N, T}) where {R,N,T}
P = strides(parent(A))
if sizeof(R) == sizeof(T)
P
elseif sizeof(R) > sizeof(T)
x = Base.tail(P)
fx = first(x)
if fx isa Int
(One(), bmap(Base.sdiv_int, Base.tail(x), fx)...)
else
(One(), bmap(÷, Base.tail(x), fx)...)
end
else
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
end
end
end
else
# plain `reinterpret(...)`
@inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T}
end
# plain `reinterpret(...)`
@inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T}
P = strides(parent(A))
if sizeof(R) == sizeof(T)
P
P
elseif sizeof(R) > sizeof(T)
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
else # sizeof(R) < sizeof(T)
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
end
end
end
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))

Expand Down
15 changes: 14 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ ArrayInterface.parent_type(::Type{DenseWrapper{T,N,P}}) where {T,N,P} = P
Am = @MMatrix rand(2,10);
@test @inferred(ArrayInterface.strides(view(Am,1,:))) === (StaticInt(2),)

if VERSION ≥ v"1.6.0-DEV.1581" # reinterpret(reshape,...) tests
if isdefined(Base, :ReshapedReinterpretArray) # reinterpret(reshape,...) tests
C1 = reinterpret(reshape, Float64, PermutedDimsArray(Array{Complex{Float64}}(undef, 3,4,5), (2,1,3)));
C2 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(view(A,1:2,:,:), (1,3,2)));
C3 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(Wrapper(reshape(view(x, 1:24), (2,3,4))), (1,3,2)));
Expand Down Expand Up @@ -717,6 +717,8 @@ end
@test @inferred(ArrayInterface.strides(u_view)) == (4,)
@test @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
@test @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)

@test_broken @inferred(ArrayInterface.axes(u_vectors)) isa ArrayInterface.axes_types(u_vectors)
Copy link
Collaborator

@chriselrod chriselrod Sep 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I take it that inference fails here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's that it isn't of the type returned by axes_types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to do some clean up on axes and can fix this. I think we need to just go all in on OptionallyStaticUnitRange here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries, and no hurry (there's nothing wrong with a broken test, it's essentially a TODO list).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you want me to fix this before this PR went forward or afterwards. Just want to make sure this PR isn't stalled because of me.

end

@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false
Expand Down Expand Up @@ -834,6 +836,17 @@ end
@test @inferred(lzaxis[1:2]) === axis[1:2]
@test @inferred(ArrayInterface.axes(Array{Float64}(undef, 4)')) === (StaticInt(1):StaticInt(1),Base.OneTo(4))
@test @inferred(ArrayInterface.axes(Array{Float64}(undef, 4, 3)')) === (Base.OneTo(3),Base.OneTo(4))

if isdefined(Base, :ReshapedReinterpretArray)
a = rand(3, 5)
ua = reinterpret(reshape, UInt64, a)
@test ArrayInterface.axes(ua) === ArrayInterface.axes(a)
@test @inferred(ArrayInterface.axes(ua)) isa ArrayInterface.axes_types(ua)
u8a = reinterpret(reshape, UInt8, a)
@test @inferred(ArrayInterface.axes(u8a)) isa ArrayInterface.axes_types(u8a)
fa = reinterpret(reshape, Float64, copy(u8a))
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
end
end

@testset "arrayinterface" begin
Expand Down