Skip to content

Commit 3b7062c

Browse files
committed
Make for iter::CartesianIndices{1/2} better vectorized.
1 parent f2c627e commit 3b7062c

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

base/multidimensional.jl

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -401,35 +401,38 @@ module IteratorsMD
401401
iterfirst, iterfirst
402402
end
403403
@inline function iterate(iter::CartesianIndices, state)
404-
valid, I = __inc(state.I, iter.indices)
404+
valid, I = __inc(state.I, iter.indices, Val(ndims(iter)))
405405
valid || return nothing
406406
return CartesianIndex(I...), CartesianIndex(I...)
407407
end
408408

409409
# increment & carry
410410
@inline function inc(state, indices)
411-
_, I = __inc(state, indices)
411+
_, I = __inc(state, indices, Val(length(state)))
412412
return CartesianIndex(I...)
413413
end
414414

415415
# Unlike ordinary ranges, CartesianIndices continues the iteration in the next column when the
416416
# current column is consumed. The implementation is written recursively to achieve this.
417417
# `iterate` returns `Union{Nothing, Tuple}`, we explicitly pass a `valid` flag to eliminate
418418
# the type instability inside the core `__inc` logic, and this gives better runtime performance.
419-
__inc(::Tuple{}, ::Tuple{}) = false, ()
420-
@inline function __inc(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt})
419+
__inc(::Tuple{}, ::Tuple{}, ::Val) = false, ()
420+
@inline function __inc(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt}, ::Val{N}) where {N}
421421
rng = indices[1]
422422
I = state[1] + step(rng)
423-
valid = __is_valid_range(I, rng) && state[1] != last(rng)
423+
if N == 1
424+
valid = state[1] != last(rng)
425+
else
426+
valid = __is_valid_range(I, rng)
427+
end
424428
return valid, (I, )
425429
end
426-
@inline function __inc(state::Tuple{Int,Int,Vararg{Int}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}})
430+
@inline function __inc(state::Tuple{Int,Int,Vararg{Int}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}}, ndim::Val)
427431
rng = indices[1]
428-
I = state[1] + step(rng)
429-
if __is_valid_range(I, rng) && state[1] != last(rng)
430-
return true, (I, tail(state)...)
432+
if state[1] != last(rng)
433+
return true, (state[1] + step(rng), tail(state)...)
431434
end
432-
valid, I = __inc(tail(state), tail(indices))
435+
valid, I = __inc(tail(state), tail(indices), ndim)
433436
return valid, (first(rng), I...)
434437
end
435438

@@ -516,32 +519,35 @@ module IteratorsMD
516519
iterfirst, iterfirst
517520
end
518521
@inline function iterate(r::Reverse{<:CartesianIndices}, state)
519-
valid, I = __dec(state.I, r.itr.indices)
522+
valid, I = __dec(state.I, r.itr.indices, Val(ndims(r.itr)))
520523
valid || return nothing
521524
return CartesianIndex(I...), CartesianIndex(I...)
522525
end
523526

524527
# decrement & carry
525528
@inline function dec(state, indices)
526-
_, I = __dec(state, indices)
529+
_, I = __dec(state, indices, Val(length(state)))
527530
return CartesianIndex(I...)
528531
end
529532

530533
# decrement post check to avoid integer overflow
531-
@inline __dec(::Tuple{}, ::Tuple{}) = false, ()
532-
@inline function __dec(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt})
534+
@inline __dec(::Tuple{}, ::Tuple{}, ::Val) = false, ()
535+
@inline function __dec(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt}, ::Val{N}) where {N}
533536
rng = indices[1]
534537
I = state[1] - step(rng)
535-
valid = __is_valid_range(I, rng) && state[1] != first(rng)
538+
if N == 1
539+
valid = state[1] != first(rng)
540+
else
541+
valid = __is_valid_range(I, rng)
542+
end
536543
return valid, (I,)
537544
end
538-
@inline function __dec(state::Tuple{Int,Int,Vararg{Int}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}})
545+
@inline function __dec(state::Tuple{Int,Int,Vararg{Int}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}}, ndim::Val)
539546
rng = indices[1]
540-
I = state[1] - step(rng)
541-
if __is_valid_range(I, rng) && state[1] != first(rng)
542-
return true, (I, tail(state)...)
547+
if state[1] != first(rng)
548+
return true, (state[1] - step(rng), tail(state)...)
543549
end
544-
valid, I = __dec(tail(state), tail(indices))
550+
valid, I = __dec(tail(state), tail(indices), ndim)
545551
return valid, (last(rng), I...)
546552
end
547553

0 commit comments

Comments
 (0)