Skip to content

Commit 16c70a6

Browse files
mcabbottararslan
authored andcommitted
Add stack (Julia PR 43334)
Backport of PR 777 to the `release-3` branch. (cherry picked from commit dce2f96)
1 parent 952300c commit 16c70a6

File tree

5 files changed

+349
-2
lines changed

5 files changed

+349
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Compat"
22
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
3-
version = "3.45.0"
3+
version = "3.46.0"
44

55
[deps]
66
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ changes in `julia`.
7272
* `popat!` removes the item at the given `i` and returns it ([#36070]). (since
7373
Compat 3.45.0)
7474

75+
* `stack` combines a collection of slices into one array ([#43334]). (since Compat 3.46.0, 4.2.0)
76+
7577
* `keepat!` removes the items at all the indices which are not given and returns
7678
the modified source ([#36229], [#42351]). (since Compat 3.44.0, 4.1.0)
7779

@@ -322,3 +324,4 @@ Note that you should specify the correct minimum version for `Compat` in the
322324
[#36229]: https://github.com/JuliaLang/julia/issues/36229
323325
[#39245]: https://github.com/JuliaLang/julia/issues/39245
324326
[#42351]: https://github.com/JuliaLang/julia/issues/42351
327+
[#43334]: https://github.com/JuliaLang/julia/issues/43334

src/Compat.jl

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,241 @@ end
13871387
end
13881388

13891389
include("iterators.jl")
1390+
1391+
# https://github.com/JuliaLang/julia/pull/43334
1392+
if VERSION < v"1.9.0-DEV.1163"
1393+
export stack
1394+
1395+
"""
1396+
stack(iter; [dims])
1397+
1398+
Combine a collection of arrays (or other iterable objects) of equal size
1399+
into one larger array, by arranging them along one or more new dimensions.
1400+
1401+
By default the axes of the elements are placed first,
1402+
giving `size(result) = (size(first(iter))..., size(iter)...)`.
1403+
This has the same order of elements as [`Iterators.flatten`](@ref)`(iter)`.
1404+
1405+
With keyword `dims::Integer`, instead the `i`th element of `iter` becomes the slice
1406+
[`selectdim`](@ref)`(result, dims, i)`, so that `size(result, dims) == length(iter)`.
1407+
In this case `stack` reverses the action of [`eachslice`](@ref) with the same `dims`.
1408+
1409+
The various [`cat`](@ref) functions also combine arrays. However, these all
1410+
extend the arrays' existing (possibly trivial) dimensions, rather than placing
1411+
the arrays along new dimensions.
1412+
They also accept arrays as separate arguments, rather than a single collection.
1413+
1414+
!!! compat "Julia 1.9"
1415+
This function is available in Julia 1.9, or in Compat 4.2.
1416+
1417+
# Examples
1418+
```jldoctest
1419+
julia> vecs = (1:2, [30, 40], Float32[500, 600]);
1420+
1421+
julia> mat = stack(vecs)
1422+
2×3 Matrix{Float32}:
1423+
1.0 30.0 500.0
1424+
2.0 40.0 600.0
1425+
1426+
julia> mat == hcat(vecs...) == reduce(hcat, collect(vecs))
1427+
true
1428+
1429+
julia> vec(mat) == vcat(vecs...) == reduce(vcat, collect(vecs))
1430+
true
1431+
1432+
julia> stack(zip(1:4, 10:99)) # accepts any iterators of iterators
1433+
2×4 Matrix{Int64}:
1434+
1 2 3 4
1435+
10 11 12 13
1436+
1437+
julia> vec(ans) == collect(Iterators.flatten(zip(1:4, 10:99)))
1438+
true
1439+
1440+
julia> stack(vecs; dims=1) # unlike any cat function, 1st axis of vecs[1] is 2nd axis of result
1441+
3×2 Matrix{Float32}:
1442+
1.0 2.0
1443+
30.0 40.0
1444+
500.0 600.0
1445+
1446+
julia> x = rand(3,4);
1447+
1448+
julia> x == stack(eachcol(x)) == stack(eachrow(x), dims=1) # inverse of eachslice
1449+
true
1450+
```
1451+
1452+
Higher-dimensional examples:
1453+
1454+
```jldoctest
1455+
julia> A = rand(5, 7, 11);
1456+
1457+
julia> E = eachslice(A, dims=2); # a vector of matrices
1458+
1459+
julia> (element = size(first(E)), container = size(E))
1460+
(element = (5, 11), container = (7,))
1461+
1462+
julia> stack(E) |> size
1463+
(5, 11, 7)
1464+
1465+
julia> stack(E) == stack(E; dims=3) == cat(E...; dims=3)
1466+
true
1467+
1468+
julia> A == stack(E; dims=2)
1469+
true
1470+
1471+
julia> M = (fill(10i+j, 2, 3) for i in 1:5, j in 1:7);
1472+
1473+
julia> (element = size(first(M)), container = size(M))
1474+
(element = (2, 3), container = (5, 7))
1475+
1476+
julia> stack(M) |> size # keeps all dimensions
1477+
(2, 3, 5, 7)
1478+
1479+
julia> stack(M; dims=1) |> size # vec(container) along dims=1
1480+
(35, 2, 3)
1481+
1482+
julia> hvcat(5, M...) |> size # hvcat puts matrices next to each other
1483+
(14, 15)
1484+
```
1485+
"""
1486+
stack(iter; dims=:) = _stack(dims, iter)
1487+
1488+
"""
1489+
stack(f, args...; [dims])
1490+
1491+
Apply a function to each element of a collection, and `stack` the result.
1492+
Or to several collections, [`zip`](@ref)ped together.
1493+
1494+
The function should return arrays (or tuples, or other iterators) all of the same size.
1495+
These become slices of the result, each separated along `dims` (if given) or by default
1496+
along the last dimensions.
1497+
1498+
See also [`mapslices`](@ref), [`eachcol`](@ref).
1499+
1500+
# Examples
1501+
```jldoctest
1502+
julia> stack(c -> (c, c-32), "julia")
1503+
2×5 Matrix{Char}:
1504+
'j' 'u' 'l' 'i' 'a'
1505+
'J' 'U' 'L' 'I' 'A'
1506+
1507+
julia> stack(eachrow([1 2 3; 4 5 6]), (10, 100); dims=1) do row, n
1508+
vcat(row, row .* n, row ./ n)
1509+
end
1510+
2×9 Matrix{Float64}:
1511+
1.0 2.0 3.0 10.0 20.0 30.0 0.1 0.2 0.3
1512+
4.0 5.0 6.0 400.0 500.0 600.0 0.04 0.05 0.06
1513+
```
1514+
"""
1515+
stack(f, iter; dims=:) = _stack(dims, f(x) for x in iter)
1516+
stack(f, xs, yzs...; dims=:) = _stack(dims, f(xy...) for xy in zip(xs, yzs...))
1517+
1518+
_stack(dims::Union{Integer, Colon}, iter) = _stack(dims, Base.IteratorSize(iter), iter)
1519+
1520+
_stack(dims, ::Base.IteratorSize, iter) = _stack(dims, collect(iter))
1521+
1522+
function _stack(dims, ::Union{Base.HasShape, Base.HasLength}, iter)
1523+
S = Base.@default_eltype iter
1524+
T = S != Union{} ? eltype(S) : Any # Union{} occurs for e.g. stack(1,2), postpone the error
1525+
if isconcretetype(T)
1526+
_typed_stack(dims, T, S, iter)
1527+
else # Need to look inside, but shouldn't run an expensive iterator twice:
1528+
array = iter isa Union{Tuple, AbstractArray} ? iter : collect(iter)
1529+
isempty(array) && return _empty_stack(dims, T, S, iter)
1530+
T2 = mapreduce(eltype, promote_type, array)
1531+
_typed_stack(dims, T2, eltype(array), array)
1532+
end
1533+
end
1534+
1535+
function _typed_stack(::Colon, ::Type{T}, ::Type{S}, A, Aax=_iterator_axes(A)) where {T, S}
1536+
xit = iterate(A)
1537+
nothing === xit && return _empty_stack(:, T, S, A)
1538+
x1, _ = xit
1539+
ax1 = _iterator_axes(x1)
1540+
B = similar(_ensure_array(x1), T, ax1..., Aax...)
1541+
off = firstindex(B)
1542+
len = length(x1)
1543+
while xit !== nothing
1544+
x, state = xit
1545+
_stack_size_check(x, ax1)
1546+
copyto!(B, off, x)
1547+
off += len
1548+
xit = iterate(A, state)
1549+
end
1550+
B
1551+
end
1552+
1553+
_iterator_axes(x) = _iterator_axes(x, Base.IteratorSize(x))
1554+
_iterator_axes(x, ::Base.HasLength) = (Base.OneTo(length(x)),)
1555+
_iterator_axes(x, ::Base.IteratorSize) = axes(x)
1556+
1557+
# For some dims values, stack(A; dims) == stack(vec(A)), and the : path will be faster
1558+
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} =
1559+
_typed_stack(dims, T, S, Base.IteratorSize(S), A)
1560+
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::Base.HasLength, A) where {T,S} =
1561+
_typed_stack(dims, T, S, Base.HasShape{1}(), A)
1562+
function _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::Base.HasShape{N}, A) where {T,S,N}
1563+
if dims == N+1
1564+
_typed_stack(:, T, S, A, (_vec_axis(A),))
1565+
else
1566+
_dim_stack(dims, T, S, A)
1567+
end
1568+
end
1569+
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::Base.IteratorSize, A) where {T,S} =
1570+
_dim_stack(dims, T, S, A)
1571+
1572+
_vec_axis(A, ax=_iterator_axes(A)) = length(ax) == 1 ? only(ax) : Base.OneTo(prod(length, ax; init=1))
1573+
1574+
@constprop :aggressive function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
1575+
xit = Iterators.peel(A)
1576+
nothing === xit && return _empty_stack(dims, T, S, A)
1577+
x1, xrest = xit
1578+
ax1 = _iterator_axes(x1)
1579+
N1 = length(ax1)+1
1580+
dims in 1:N1 || throw(ArgumentError(string("cannot stack slices ndims(x) = ", N1-1, " along dims = ", dims)))
1581+
1582+
newaxis = _vec_axis(A)
1583+
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
1584+
B = similar(_ensure_array(x1), T, outax...)
1585+
1586+
if dims == 1
1587+
_dim_stack!(Val(1), B, x1, xrest)
1588+
elseif dims == 2
1589+
_dim_stack!(Val(2), B, x1, xrest)
1590+
else
1591+
_dim_stack!(Val(dims), B, x1, xrest)
1592+
end
1593+
B
1594+
end
1595+
1596+
function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
1597+
before = ntuple(d -> Colon(), dims - 1)
1598+
after = ntuple(d -> Colon(), ndims(B) - dims)
1599+
1600+
i = firstindex(B, dims)
1601+
copyto!(view(B, before..., i, after...), x1)
1602+
1603+
for x in xrest
1604+
_stack_size_check(x, _iterator_axes(x1))
1605+
i += 1
1606+
@inbounds copyto!(view(B, before..., i, after...), x)
1607+
end
1608+
end
1609+
1610+
@inline function _stack_size_check(x, ax1::Tuple)
1611+
if _iterator_axes(x) != ax1
1612+
uax1 = map(UnitRange, ax1)
1613+
uaxN = map(UnitRange, axes(x))
1614+
throw(DimensionMismatch(
1615+
string("stack expects uniform slices, got axes(x) == ", uaxN, " while first had ", uax1)))
1616+
end
1617+
end
1618+
1619+
_ensure_array(x::AbstractArray) = x
1620+
_ensure_array(x) = 1:0 # passed to similar, makes stack's output an Array
1621+
1622+
_empty_stack(_...) = throw(ArgumentError("`stack` on an empty collection is not allowed"))
1623+
end
1624+
13901625
include("deprecated.jl")
13911626

13921627
end # module Compat

src/iterators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ for n in names(Base.Iterators)
3232
end
3333

3434
# Import unexported public APIs
35-
using Base.Iterators: filter
35+
using Base.Iterators: filter, peel
3636

3737
# https://github.com/JuliaLang/julia/pull/33437
3838
if VERSION < v"1.4.0-DEV.291" # 5f013d82f92026f7dfbe4234f283658beb1f8a2a

test/runtests.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,3 +1466,112 @@ end
14661466
badpop() = @inbounds popat!([1], 2)
14671467
@test_throws BoundsError badpop()
14681468
end
1469+
1470+
# https://github.com/JuliaLang/julia/pull/43334
1471+
@testset "stack" begin
1472+
# Basics
1473+
for args in ([[1, 2]], [1:2, 3:4], [[1 2; 3 4], [5 6; 7 8]],
1474+
AbstractVector[1:2, [3.5, 4.5]], Vector[[1,2], [3im, 4im]],
1475+
[[1:2, 3:4], [5:6, 7:8]], [fill(1), fill(2)])
1476+
X = stack(args)
1477+
Y = cat(args...; dims=ndims(args[1])+1)
1478+
@test X == Y
1479+
@test typeof(X) === typeof(Y)
1480+
1481+
X2 = stack(x for x in args)
1482+
@test X2 == Y
1483+
@test typeof(X2) === typeof(Y)
1484+
1485+
X3 = stack(x for x in args if true)
1486+
@test X3 == Y
1487+
@test typeof(X3) === typeof(Y)
1488+
1489+
if isconcretetype(eltype(args))
1490+
@inferred stack(args)
1491+
@inferred stack(x for x in args)
1492+
end
1493+
end
1494+
1495+
# Higher dims
1496+
@test size(stack([rand(2,3) for _ in 1:4, _ in 1:5])) == (2,3,4,5)
1497+
@test size(stack(rand(2,3) for _ in 1:4, _ in 1:5)) == (2,3,4,5)
1498+
@test size(stack(rand(2,3) for _ in 1:4, _ in 1:5 if true)) == (2, 3, 20)
1499+
@test size(stack([rand(2,3) for _ in 1:4, _ in 1:5]; dims=1)) == (20, 2, 3)
1500+
@test size(stack(rand(2,3) for _ in 1:4, _ in 1:5; dims=2)) == (2, 20, 3)
1501+
1502+
# Tuples
1503+
@test stack([(1,2), (3,4)]) == [1 3; 2 4]
1504+
@test stack(((1,2), (3,4))) == [1 3; 2 4]
1505+
@test stack(Any[(1,2), (3,4)]) == [1 3; 2 4]
1506+
@test stack([(1,2), (3,4)]; dims=1) == [1 2; 3 4]
1507+
@test stack(((1,2), (3,4)); dims=1) == [1 2; 3 4]
1508+
@test stack(Any[(1,2), (3,4)]; dims=1) == [1 2; 3 4]
1509+
@test size(@inferred stack(Iterators.product(1:3, 1:4))) == (2,3,4)
1510+
@test @inferred(stack([('a', 'b'), ('c', 'd')])) == ['a' 'c'; 'b' 'd']
1511+
@test @inferred(stack([(1,2+3im), (4, 5+6im)])) isa Matrix{Number}
1512+
1513+
# stack(f, iter)
1514+
@test @inferred(stack(x -> [x, 2x], 3:5)) == [3 4 5; 6 8 10]
1515+
@test @inferred(stack(x -> x*x'/2, [1:2, 3:4])) == reshape([0.5, 1.0, 1.0, 2.0, 4.5, 6.0, 6.0, 8.0], 2, 2, 2)
1516+
@test @inferred(stack(*, [1:2, 3:4], 5:6)) == [5 18; 10 24]
1517+
1518+
# Iterators
1519+
@test stack([(a=1,b=2), (a=3,b=4)]) == [1 3; 2 4]
1520+
@test stack([(a=1,b=2), (c=3,d=4)]) == [1 3; 2 4]
1521+
@test stack([(a=1,b=2), (c=3,d=4)]; dims=1) == [1 2; 3 4]
1522+
@test stack([(a=1,b=2), (c=3,d=4)]; dims=2) == [1 3; 2 4]
1523+
@test stack((x/y for x in 1:3) for y in 4:5) == (1:3) ./ (4:5)'
1524+
@test stack((x/y for x in 1:3) for y in 4:5; dims=1) == (1:3)' ./ (4:5)
1525+
1526+
# Exotic
1527+
ips = ((Iterators.product([i,i^2], [2i,3i,4i], 1:4)) for i in 1:5)
1528+
@test size(stack(ips)) == (2, 3, 4, 5)
1529+
@test stack(ips) == cat(collect.(ips)...; dims=4)
1530+
ips_cat2 = cat(reshape.(collect.(ips), Ref((2,1,3,4)))...; dims=2)
1531+
@test stack(ips; dims=2) == ips_cat2
1532+
@test stack(collect.(ips); dims=2) == ips_cat2
1533+
ips_cat3 = cat(reshape.(collect.(ips), Ref((2,3,1,4)))...; dims=3)
1534+
@test stack(ips; dims=3) == ips_cat3 # path for non-array accumulation on non-final dims
1535+
@test stack(collect, ips; dims=3) == ips_cat3 # ... and for array accumulation
1536+
@test stack(collect.(ips); dims=3) == ips_cat3
1537+
1538+
# Trivial, because numbers are iterable:
1539+
@test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3))
1540+
1541+
# Allocation tests
1542+
xv = [rand(10) for _ in 1:100]
1543+
xt = Tuple.(xv)
1544+
for dims in (1, 2, :)
1545+
@test stack(xv; dims=dims) == stack(xt; dims=dims)
1546+
@test_skip 9000 > @allocated stack(xv; dims=dims)
1547+
@test_skip 9000 > @allocated stack(xt; dims=dims)
1548+
end
1549+
xr = (reshape(1:1000,10,10,10) for _ = 1:1000)
1550+
for dims in (1, 2, 3, :)
1551+
stack(xr; dims=dims)
1552+
@test_skip 8.1e6 > @allocated stack(xr; dims=dims)
1553+
end
1554+
1555+
# Mismatched sizes
1556+
@test_throws DimensionMismatch stack([1:2, 1:3])
1557+
@test_throws DimensionMismatch stack([1:2, 1:3]; dims=1)
1558+
@test_throws DimensionMismatch stack([1:2, 1:3]; dims=2)
1559+
@test_throws DimensionMismatch stack([(1,2), (3,4,5)])
1560+
@test_throws DimensionMismatch stack([(1,2), (3,4,5)]; dims=1)
1561+
@test_throws DimensionMismatch stack(x for x in [1:2, 1:3])
1562+
@test_throws DimensionMismatch stack([[5 6; 7 8], [1, 2, 3, 4]])
1563+
@test_throws DimensionMismatch stack([[5 6; 7 8], [1, 2, 3, 4]]; dims=1)
1564+
@test_throws DimensionMismatch stack(x for x in [[5 6; 7 8], [1, 2, 3, 4]])
1565+
# Inner iterator of unknown length
1566+
@test_throws MethodError stack((x for x in 1:3 if true) for _ in 1:4)
1567+
@test_throws MethodError stack((x for x in 1:3 if true) for _ in 1:4; dims=1)
1568+
1569+
@test_throws ArgumentError stack([1:3, 4:6]; dims=0)
1570+
@test_throws ArgumentError stack([1:3, 4:6]; dims=3)
1571+
@test_throws ArgumentError stack(abs2, 1:3; dims=2)
1572+
1573+
# Empty
1574+
@test_throws ArgumentError stack(())
1575+
@test_throws ArgumentError stack([])
1576+
@test_throws ArgumentError stack(x for x in 1:3 if false)
1577+
end

0 commit comments

Comments
 (0)