From 64a83b79ecae56c9922cfe72d4cb432afdc8de4b Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 26 Sep 2021 02:38:14 -0500 Subject: [PATCH] ArrayInterface.axes for ReshapedReinterpretArray This specializes `axes` for ReshapedReinterpretArray, and streamlines a bit of the code surrounding this type. --- src/ArrayInterface.jl | 8 ++-- src/axes.jl | 15 ++++-- src/stridelayout.jl | 108 +++++++++++++++++++----------------------- test/runtests.jl | 15 +++++- 4 files changed, 76 insertions(+), 70 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index aa263b3a7..92e3e7f41 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -15,12 +15,10 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr const CanonicalInt = Union{Int,StaticInt} -@static if VERSION ≥ v"1.6.0-DEV.1581" - _is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true - _is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false -else - _is_reshaped(::Type{ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = false +@static if isdefined(Base, :ReshapedReinterpretArray) + _is_reshaped(::Type{<:Base.ReshapedReinterpretArray}) = true end +_is_reshaped(::Type{<:ReinterpretArray}) = false Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) diff --git a/src/axes.jl b/src/axes.jl index 1d13eb5e6..8e2c64ab7 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -149,10 +149,10 @@ end @inline _axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version @inline function _axes(A::ReinterpretArray{T,N,S}, dim::Integer) where {T,N,S} - if _is_reshaped(typeof(A)) && (sizeof(S) > sizeof(T)) && dim == 1 - return One():static(div(sizeof(S), sizeof(T))) - end - Base.axes(A, Int(dim)) # TODO implement ArrayInterface version + if _is_reshaped(typeof(A)) && (sizeof(S) > sizeof(T)) && dim == 1 + return One():static(div(sizeof(S), sizeof(T))) + end + Base.axes(A, Int(dim)) # TODO implement ArrayInterface version end @inline _axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version @@ -174,6 +174,13 @@ axes(A::Union{Transpose,Adjoint}) = _axes(A, parent(A)) _axes(A::Union{Transpose,Adjoint}, p::AbstractVector) = (One():One(), axes(p, One())) _axes(A::Union{Transpose,Adjoint}, p::AbstractMatrix) = (axes(p, StaticInt(2)), axes(p, One())) axes(A::SubArray) = Base.axes(A) # TODO implement ArrayInterface version +if isdefined(Base, :ReshapedReinterpretArray) + function axes(A::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S} + sizeof(S) > sizeof(T) && return (_axes(A, 1), axes(parent(A))...) + sizeof(S) < sizeof(T) && return Base.tail(axes(parent(A))) + return axes(parent(A)) + end +end axes(A::ReinterpretArray) = Base.axes(A) # TODO implement ArrayInterface version axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface version axes(A::CartesianIndices) = A.indices diff --git a/src/stridelayout.jl b/src/stridelayout.jl index e5490e7ee..8bc1add14 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -237,31 +237,31 @@ stride_rank(x, i) = stride_rank(x)[i] function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} return nstatic(Val(N)) end -if VERSION ≥ v"1.6.0-DEV.1581" - @inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}} - NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}())) - end - @inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...) - @inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr)) - # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface - # doesn't currently have a means of representing. - @inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}} - _reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}())) - end - @inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One() - @inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero() - @generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N} - for n in 1:N - if t.parameters[n+1].parameters[1] === 2 - if d.parameters[n+1] === True - return :(StaticInt{$n}()) - else - return :(Zero()) +if isdefined(Base, :ReshapedReinterpretArray) + @inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}} + NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}())) + end + @inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...) + @inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr)) + # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface + # doesn't currently have a means of representing. + @inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}} + _reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}())) + end + @inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One() + @inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero() + @generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N} + for n in 1:N + if t.parameters[n+1].parameters[1] === 2 + if d.parameters[n+1] === True + return :(StaticInt{$n}()) + else + return :(Zero()) + end + end end - end + :(Zero()) end - :(Zero()) - end end function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M} @@ -376,8 +376,8 @@ end function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) end -if VERSION ≥ v"1.6.0-DEV.1581" - @inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}} +if isdefined(Base, :ReshapedReinterpretArray) + @inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReshapedReinterpretArray{<:Any, NA, <:Any, B}} ddb = dense_dims(B) IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb)) end @@ -539,47 +539,35 @@ end @inline bmap(f::F, t::Tuple{}, x::Number) where {F} = () @inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), ) @inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...) -@static if VERSION ≥ v"1.6.0-DEV.1581" - # from `reinterpret(reshape, ...)` - @inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B} - P = strides(parent(A)) - if sizeof(R) == sizeof(T) - P - elseif sizeof(R) > sizeof(T) - x = Base.tail(P) - fx = first(x) - if fx isa Int - (One(), bmap(Base.sdiv_int, Base.tail(x), fx)...) - else - (One(), bmap(÷, Base.tail(x), fx)...) - end - else - (One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...) - end - end - # plain `reinterpret(...)` - @inline function strides(A::Base.ReinterpretArray{R, N, T, B, false}) where {R,N,T,B} - P = strides(parent(A)) - if sizeof(R) == sizeof(T) - P - elseif sizeof(R) > sizeof(T) - (first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...) - else # sizeof(R) < sizeof(T) - (first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...) +@static if isdefined(Base, :ReshapedReinterpretArray) + # from `reinterpret(reshape, ...)` + @inline function strides(A::Base.ReshapedReinterpretArray{R, N, T}) where {R,N,T} + P = strides(parent(A)) + if sizeof(R) == sizeof(T) + P + elseif sizeof(R) > sizeof(T) + x = Base.tail(P) + fx = first(x) + if fx isa Int + (One(), bmap(Base.sdiv_int, Base.tail(x), fx)...) + else + (One(), bmap(÷, Base.tail(x), fx)...) + end + else + (One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...) + end end - end -else - # plain `reinterpret(...)` - @inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T} +end +# plain `reinterpret(...)` +@inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T} P = strides(parent(A)) if sizeof(R) == sizeof(T) - P + P elseif sizeof(R) > sizeof(T) - (first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...) + (first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...) else # sizeof(R) < sizeof(T) - (first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...) + (first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...) end - end end #@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A)) diff --git a/test/runtests.jl b/test/runtests.jl index da99b7222..06649d163 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -433,7 +433,7 @@ ArrayInterface.parent_type(::Type{DenseWrapper{T,N,P}}) where {T,N,P} = P Am = @MMatrix rand(2,10); @test @inferred(ArrayInterface.strides(view(Am,1,:))) === (StaticInt(2),) - if VERSION ≥ v"1.6.0-DEV.1581" # reinterpret(reshape,...) tests + if isdefined(Base, :ReshapedReinterpretArray) # reinterpret(reshape,...) tests C1 = reinterpret(reshape, Float64, PermutedDimsArray(Array{Complex{Float64}}(undef, 3,4,5), (2,1,3))); C2 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(view(A,1:2,:,:), (1,3,2))); C3 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(Wrapper(reshape(view(x, 1:24), (2,3,4))), (1,3,2))); @@ -717,6 +717,8 @@ end @test @inferred(ArrayInterface.strides(u_view)) == (4,) @test @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,) @test @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4) + + @test_broken @inferred(ArrayInterface.axes(u_vectors)) isa ArrayInterface.axes_types(u_vectors) end @test ArrayInterface.can_avx(ArrayInterface.can_avx) == false @@ -834,6 +836,17 @@ end @test @inferred(lzaxis[1:2]) === axis[1:2] @test @inferred(ArrayInterface.axes(Array{Float64}(undef, 4)')) === (StaticInt(1):StaticInt(1),Base.OneTo(4)) @test @inferred(ArrayInterface.axes(Array{Float64}(undef, 4, 3)')) === (Base.OneTo(3),Base.OneTo(4)) + + if isdefined(Base, :ReshapedReinterpretArray) + a = rand(3, 5) + ua = reinterpret(reshape, UInt64, a) + @test ArrayInterface.axes(ua) === ArrayInterface.axes(a) + @test @inferred(ArrayInterface.axes(ua)) isa ArrayInterface.axes_types(ua) + u8a = reinterpret(reshape, UInt8, a) + @test @inferred(ArrayInterface.axes(u8a)) isa ArrayInterface.axes_types(u8a) + fa = reinterpret(reshape, Float64, copy(u8a)) + @inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa) + end end @testset "arrayinterface" begin