-
Notifications
You must be signed in to change notification settings - Fork 39
SubIndex, LinearSubIndex, and PermutedIndex types
#202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
6de3ecf
2d5c260
ace2c57
a963ebd
b309e1f
3da21dd
adaad82
d2fef13
62457e0
56bdcf0
58b9d7b
d391ba5
55ca0e5
aa14e0e
af92766
b424766
9f631ad
fc95c21
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ const MatrixIndex = ArrayIndex{2} | |
|
|
||
| const VectorIndex = ArrayIndex{1} | ||
|
|
||
| Base.ndims(::ArrayIndex{N}) where {N} = N | ||
| Base.ndims(::Type{<:ArrayIndex{N}}) where {N} = N | ||
|
|
||
| struct BidiagonalIndex <: MatrixIndex | ||
|
|
@@ -183,6 +184,10 @@ function BandedBlockBandedMatrixIndex( | |
| rowindobj, colindobj | ||
| end | ||
|
|
||
| Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 | ||
| Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count | ||
| Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count | ||
|
|
||
| """ | ||
| StrideIndex(x) | ||
|
|
||
|
|
@@ -204,11 +209,124 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N} | |
| end | ||
| end | ||
|
|
||
| Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 | ||
| Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count | ||
| Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count | ||
| """ | ||
| PermutedIndex | ||
|
|
||
| Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing | ||
| parent indices. | ||
| """ | ||
| struct PermutedIndex{N,I1,I2} <: ArrayIndex{N} | ||
| PermutedIndex{N,I1,I2}() where {N,I1,I2} = new{N,I1,I2}() | ||
| function PermutedIndex(p::Tuple{Vararg{StaticInt,N}}, ip::Tuple{Vararg{StaticInt}}) where {N} | ||
| PermutedIndex{N,known(p),known(ip)}() | ||
| end | ||
| end | ||
|
|
||
| function Base.getindex(x::PermutedIndex{2,(2,1),(2,)}, i::AbstractCartesianIndex{2}) | ||
| getfield(Tuple(i), 2) | ||
| end | ||
| @inline function Base.getindex(x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}) where {N,I1,I2} | ||
| return NDIndex(permute(Tuple(i), Val(I2))) | ||
| end | ||
|
|
||
| """ | ||
| SubIndex(indices) | ||
|
|
||
| Subtype of `ArrayIndex` that provides a multidimensional view of another `ArrayIndex`. | ||
| """ | ||
| struct SubIndex{N,I} <: ArrayIndex{N} | ||
| indices::I | ||
|
|
||
| SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds) | ||
| end | ||
|
|
||
| @inline function Base.getindex(x::SubIndex{N}, i::AbstractCartesianIndex{N}) where {N} | ||
| return NDIndex(_reindex(x.indices, Tuple(i))) | ||
| end | ||
| @generated function _reindex(subinds::S, inds::I) where {S,I} | ||
| inds_i = 1 | ||
| subinds_i = 1 | ||
| NS = known_length(S) | ||
| NI = known_length(I) | ||
| out = Expr(:tuple) | ||
| while inds_i <= NI | ||
| subinds_type = S.parameters[subinds_i] | ||
| if subinds_type <: Integer | ||
| push!(out.args, :(getfield(subinds, $subinds_i))) | ||
| subinds_i += 1 | ||
| elseif eltype(subinds_type) <: AbstractCartesianIndex | ||
| push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...)) | ||
| inds_i += 1 | ||
| subinds_i += 1 | ||
| else | ||
| push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))) | ||
| inds_i += 1 | ||
| subinds_i += 1 | ||
| end | ||
| end | ||
| if subinds_i <= NS | ||
| for i in subinds_i:NS | ||
| push!(out.args, :(getfield(subinds, $subinds_i))) | ||
| end | ||
| end | ||
| return Expr(:block, Expr(:meta, :inline), :($out)) | ||
| end | ||
|
|
||
| """ | ||
| LinearSubIndex(offset, stride) | ||
|
|
||
| Subtype of `ArrayIndex` that provides linear indexing for `Base.FastSubArray` and | ||
| `FastContiguousSubArray`. | ||
| """ | ||
| struct LinearSubIndex{O<:CanonicalInt,S<:CanonicalInt} <: VectorIndex | ||
| offset::O | ||
| stride::S | ||
| end | ||
|
|
||
| const OffsetIndex{O} = LinearSubIndex{O,StaticInt{1}} | ||
| OffsetIndex(offset::CanonicalInt) = LinearSubIndex(offset, static(1)) | ||
|
|
||
| @inline function Base.getindex(x::LinearSubIndex, i::CanonicalInt) | ||
| getfield(x, :offset) + getfield(x, :stride) * i | ||
| end | ||
|
|
||
| """ | ||
| ComposedIndex(i1, i2) | ||
|
|
||
| A subtype of `ArrayIndex` that lazily combines index `i1` and `i2`. Indexing a | ||
| `ComposedIndex` whith `i` is equivalent to `i2[i1[i]]`. | ||
| """ | ||
| struct ComposedIndex{N,I1,I2} <: ArrayIndex{N} | ||
| i1::I1 | ||
| i2::I2 | ||
|
|
||
| ComposedIndex(i1::I1, i2::I2) where {I1,I2} = new{ndims(I1),I1,I2}(i1, i2) | ||
| end | ||
| # we should be able to assume that if `i1` was indexed without error than it's inbounds | ||
| @propagate_inbounds function Base.getindex(x::ComposedIndex) | ||
| ii = getfield(x, :i1)[] | ||
| @inbounds(getfield(x, :i2)[ii]) | ||
| end | ||
| @propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt) | ||
| ii = getfield(x, :i1)[i] | ||
| @inbounds(getfield(x, :i2)[ii]) | ||
| end | ||
| @propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex) | ||
| ii = getfield(x, :i1)[i] | ||
| @inbounds(getfield(x, :i2)[ii]) | ||
| end | ||
|
|
||
| Base.getindex(x::ArrayIndex, i::ArrayIndex) = ComposedIndex(i, x) | ||
| @inline function Base.getindex(x::ComposedIndex, i::ArrayIndex) | ||
| ComposedIndex(getfield(x, :i1)[i], getfield(x, :i2)) | ||
| end | ||
| @inline function Base.getindex(x::ArrayIndex, i::ComposedIndex) | ||
| ComposedIndex(getfield(i, :i1), x[getfield(i, :i2)]) | ||
| end | ||
| @inline function Base.getindex(x::ComposedIndex, i::ComposedIndex) | ||
| ComposedIndex(getfield(i, :i1), ComposedIndex(getfield(x, :i1)[getfield(i, :i2)], getfield(x, :i2))) | ||
| end | ||
|
|
||
| ## getindex | ||
| @propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] | ||
| @propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) | ||
| @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) | ||
|
|
@@ -288,3 +406,126 @@ end | |
| end | ||
| return Expr(:block, Expr(:meta, :inline), out) | ||
| end | ||
|
|
||
| @inline function Base.getindex(x::StrideIndex, i::SubIndex{N,I}) where {N,I} | ||
| _composed_sub_strides(stride_preserving_index(I), x, i) | ||
| end | ||
| _composed_sub_strides(::False, x::StrideIndex, i::SubIndex) = ComposedIndex(i, x) | ||
| @inline function _composed_sub_strides(::True, x::StrideIndex{N,R,C}, i::SubIndex{Ns,I}) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}} | ||
| c = static(C) | ||
| if _get_tuple(I, c) <: AbstractUnitRange | ||
| c2 = known(getfield(_from_sub_dims(I), C)) | ||
| elseif (_get_tuple(I, c) <: AbstractArray) && (_get_tuple(I, c) <: Integer) | ||
| c2 = -1 | ||
| else | ||
| c2 = nothing | ||
| end | ||
|
|
||
| pdims = _to_sub_dims(I) | ||
| o = offsets(x) | ||
| s = strides(x) | ||
| inds = getfield(i, :indices) | ||
| out = StrideIndex{Ns,permute(R, pdims),c2}( | ||
| eachop(getmul, pdims, map(maybe_static_step, inds), s), | ||
| permute(o, pdims) | ||
| ) | ||
| return OffsetIndex(reduce_tup(+, map(*, map(_diff, inds, o), s)))[out] | ||
| end | ||
| @inline _diff(::Base.Slice, ::Any) = Zero() | ||
| @inline _diff(x::AbstractRange, o) = static_first(x) - o | ||
| @inline _diff(x::Integer, o) = x - o | ||
|
|
||
| @inline function Base.getindex(x::StrideIndex{1,R,C}, ::PermutedIndex{2,(2,1),(2,)}) where {R,C} | ||
| if C === nothing | ||
| c2 = nothing | ||
| elseif C === 1 | ||
| c2 = 2 | ||
| else | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should add elseif C === 2
c2 = 1Although given the |
||
| c2 = -1 | ||
| end | ||
| s = getfield(strides(x), 1) | ||
| return StrideIndex{2,(2,1),c2}((s, s), (static(1), offset1(x))) | ||
| end | ||
| @inline function Base.getindex(x::StrideIndex{N,R,C}, ::PermutedIndex{N,perm,iperm}) where {N,R,C,perm,iperm} | ||
| if C === nothing || C === -1 | ||
| c2 = C | ||
| else | ||
| c2 = getfield(iperm, C) | ||
| end | ||
| return StrideIndex{N,permute(R, Val(perm)),c2}( | ||
| permute(strides(x), Val(perm)), | ||
| permute(offsets(x), Val(perm)), | ||
| ) | ||
| end | ||
| @inline function Base.getindex(x::PermutedIndex, i::PermutedIndex) | ||
| PermutedIndex( | ||
| permute(to_parent_dims(x), to_parent_dims(i)), | ||
| permute(from_parent_dims(x), from_parent_dims(i)) | ||
| ) | ||
| end | ||
|
|
||
| @inline function Base.getindex(x::LinearSubIndex, i::LinearSubIndex) | ||
| s = getfield(x, :stride) | ||
| LinearSubIndex( | ||
| getfield(x, :offset) + getfield(i, :offset) * s, | ||
| getfield(i, :stride) * s | ||
| ) | ||
| end | ||
| Base.getindex(::OffsetIndex{StaticInt{0}}, i::StrideIndex) = i | ||
|
|
||
|
|
||
| ## ArrayIndex constructorrs | ||
| @inline _to_cartesian(a) = CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a)))) | ||
| @inline function _to_linear(a) | ||
| N = ndims(a) | ||
| StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a)) | ||
| end | ||
|
|
||
| ## DenseArray | ||
| """ | ||
| ArrayIndex{N}(A) -> index | ||
|
|
||
| Constructs a subtype of `ArrayIndex` such that an `N` dimensional indexing argument may be | ||
| converted to an appropriate state for accessing the buffer of `A`. For example: | ||
|
|
||
| ```julia | ||
| julia> A = reshape(1:20, 4, 5); | ||
|
|
||
| julia> index = ArrayInterface.ArrayIndex{2}(A); | ||
|
|
||
| julia> ArrayInterface.buffer(A)[index[2, 2]] == A[2, 2] | ||
| true | ||
|
|
||
| ``` | ||
| """ | ||
| ArrayIndex{N}(x::DenseArray) where {N} = StrideIndex(x) | ||
| ArrayIndex{1}(x::DenseArray) = OffsetIndex(static(0)) | ||
|
|
||
| ArrayIndex{1}(x::ReshapedArray) = OffsetIndex(static(0)) | ||
| ArrayIndex{N}(x::ReshapedArray) where {N} = _to_linear(x) | ||
|
|
||
| ArrayIndex{1}(x::AbstractRange) = OffsetIndex(static(0)) | ||
|
|
||
| ## SubArray | ||
| ArrayIndex{N}(x::SubArray) where {N} = SubIndex{ndims(x)}(getfield(x, :indices)) | ||
| function ArrayIndex{1}(x::SubArray{<:Any,N}) where {N} | ||
| ComposedIndex(_to_cartesian(x), SubIndex{N}(getfield(x, :indices))) | ||
| end | ||
| ArrayIndex{1}(x::Base.FastContiguousSubArray) = OffsetIndex(getfield(x, :offset1)) | ||
| function ArrayIndex{1}(x::Base.FastSubArray) | ||
| LinearSubIndex(getfield(x, :offset1), getfield(x, :stride1)) | ||
| end | ||
|
|
||
| ## Permuted arrays | ||
| ArrayIndex{2}(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}() | ||
| ArrayIndex{2}(::VecAdjTrans) = PermutedIndex{2,(2,1),(2,)}() | ||
| ArrayIndex{1}(x::MatAdjTrans) = ComposedIndex(_to_cartesian(x), ArrayIndex{2}(x)) | ||
| ArrayIndex{1}(x::VecAdjTrans) = OffsetIndex(static(0)) # just unwrap permuting struct | ||
| ArrayIndex{1}(::PermutedDimsArray{<:Any,1}) = OffsetIndex(static(0)) | ||
| function ArrayIndex{N}(::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} | ||
| PermutedIndex{N,perm,iperm}() | ||
| end | ||
| function ArrayIndex{1}(x::PermutedDimsArray{<:Any,N,perm,iperm}) where {N,perm,iperm} | ||
| ComposedIndex(_to_cartesian(x), PermutedIndex{N,perm,iperm}()) | ||
| end | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because many
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way I'm treating
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @chriselrod, should this PR have something like the
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Sure, although why not have these things live in a separate, experimental repo? It's hard for me to assess changes like these if I don't know the vision for their future use.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hopefully recent commits involving the indexing pipeline make it more clear that these changes allow unique interactions between layers of nested arrays. If we were to have a separate package for this we'd need ArrayInterface.jl to depend on it so that we could complete the indexing interface. That being said, we don't need to strictly use what I've proposed here. The point of the |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -668,7 +668,7 @@ end | |
| end | ||
| end | ||
|
|
||
| @testset "" begin | ||
| @testset "ArrayIndex" begin | ||
| include("array_index.jl") | ||
| end | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.