Skip to content

Commit 7bec6c0

Browse files
authored
Merge pull request #406 from Nemocas/th/cat
Add generalized matrix concatenation
2 parents 19b9b38 + ffa2764 commit 7bec6c0

2 files changed

Lines changed: 161 additions & 0 deletions

File tree

src/generic/Matrix.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4764,6 +4764,139 @@ function vcat(a::AbstractAlgebra.MatElem, b::AbstractAlgebra.MatElem)
47644764
return c
47654765
end
47664766

4767+
@doc Markdown.doc"""
4768+
vcat(A::Vector{<: MatrixElem}) -> MatrixElem
4769+
> Return the horizontal concatenation of the matrices in $A$.
4770+
> All component matrices need to have the same base ring and number of columns.
4771+
"""
4772+
function vcat(A::Vector{<: MatrixElem})
4773+
return _vcat(A)
4774+
end
4775+
4776+
function Base.vcat(A::MatrixElem...)
4777+
return _vcat(A)
4778+
end
4779+
4780+
function _vcat(A)
4781+
if length(A) == 0
4782+
error("Number of matrices to concatenate must be positive")
4783+
end
4784+
4785+
if any(x -> ncols(x) != ncols(A[1]), A)
4786+
error("Matrices must have the same number of columns")
4787+
end
4788+
4789+
if any(x -> base_ring(x) != base_ring(A[1]), A)
4790+
error("Matrices must have the same base ring")
4791+
end
4792+
4793+
M = similar(A[1], sum(nrows, A), ncols(A[1]))
4794+
s = 0
4795+
for N in A
4796+
for j in 1:nrows(N)
4797+
for k in 1:ncols(N)
4798+
M[s+j, k] = N[j,k]
4799+
end
4800+
end
4801+
s += nrows(N)
4802+
end
4803+
return M
4804+
end
4805+
4806+
@doc Markdown.doc"""
4807+
hcat(A::Vector{<: MatrixElem}) -> MatrixElem
4808+
> Return the horizontal concatenating of the matrices in $A$.
4809+
> All component matrices need to have the same base ring and number of rows.
4810+
"""
4811+
function hcat(A::Vector{<: MatrixElem})
4812+
return _hcat(A)
4813+
end
4814+
4815+
function _hcat(A)
4816+
if length(A) == 0
4817+
error("Number of matrices to concatenate must be positive")
4818+
end
4819+
4820+
if any(x -> nrows(x) != nrows(A[1]), A)
4821+
error("Matrices must have the same number of rows")
4822+
end
4823+
4824+
if any(x -> base_ring(x) != base_ring(A[1]), A)
4825+
error("Matrices must have the same base ring")
4826+
end
4827+
4828+
M = similar(A[1], nrows(A[1]), sum(ncols, A))
4829+
s = 0
4830+
for N in A
4831+
for j in 1:ncols(N)
4832+
for k in 1:nrows(N)
4833+
M[k, s + j] = N[k, j]
4834+
end
4835+
end
4836+
s += ncols(N)
4837+
end
4838+
return M
4839+
end
4840+
4841+
function Base.hcat(A::MatrixElem...)
4842+
return _hcat(A)
4843+
end
4844+
4845+
function Base.cat(A::MatrixElem...;dims)
4846+
@assert dims == (1,2) || isa(dims, Int)
4847+
4848+
if isa(dims, Int)
4849+
if dims == 1
4850+
return hcat(A...)
4851+
elseif dims == 2
4852+
return vcat(A...)
4853+
else
4854+
error("dims must be 1, 2, or (1,2)")
4855+
end
4856+
end
4857+
4858+
local X
4859+
for i in 1:length(A)
4860+
if i == 1
4861+
X = hcat(A[1], zero(A[1], nrows(A[1]), sum(Int[ncols(A[j]) for j=2:length(A)])))
4862+
else
4863+
X = vcat(X, hcat(zero(A[1], nrows(A[i]), sum(ncols(A[j]) for j=1:i-1)), A[i], zero(A[1], nrows(A[i]), sum(Int[ncols(A[j]) for j in (i+1):length(A)]))))
4864+
end
4865+
end
4866+
return X
4867+
end
4868+
4869+
function Base.hvcat(rows::Tuple{Vararg{Int}}, A::MatrixElem...)
4870+
nr = 0
4871+
k = 1
4872+
for i in 1:length(rows)
4873+
nr += nrows(A[k])
4874+
k += rows[i]
4875+
end
4876+
4877+
nc = sum(ncols(A[i]) for i in 1:rows[1])
4878+
4879+
M = similar(A[1], nr, nc)
4880+
mat_offset = 0
4881+
row_offset = 0
4882+
for j in 1:length(rows)
4883+
s = 0
4884+
for i in 1:rows[j]
4885+
N = A[mat_offset + i]
4886+
for l in 1:ncols(N)
4887+
for k in 1:nrows(N)
4888+
M[row_offset + k, s + l] = N[k, l]
4889+
end
4890+
end
4891+
s += ncols(N)
4892+
end
4893+
row_offset += nrows(A[1+ mat_offset])
4894+
mat_offset += rows[j]
4895+
end
4896+
4897+
return M
4898+
end
4899+
47674900
###############################################################################
47684901
#
47694902
# Random generation

test/generic/Matrix-test.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,34 @@ function test_gen_mat_concat()
16411641
@test vcat(transpose(M1), transpose(M2)) == transpose(hcat(M1, M2))
16421642
end
16431643

1644+
A = matrix(R, 2, 2, [1, 2, 3, 4])
1645+
B = matrix(R, 4, 2, [1, 2, 3, 4, 0, 1, 0, 1])
1646+
C = matrix(R, 4, 1, [0, 1, 0, 2])
1647+
D = matrix(R, 2, 3, [1, 2, 3, 4, 5, 6])
1648+
1649+
@test hcat(B, C) == matrix(R, [1 2 0;
1650+
3 4 1;
1651+
0 1 0;
1652+
0 1 2;])
1653+
@test hcat(B, C) == [B C]
1654+
@test hcat([B, C]) == [B C]
1655+
1656+
@test vcat(A, B) == matrix(R, [1 2;
1657+
3 4;
1658+
1 2;
1659+
3 4;
1660+
0 1;
1661+
0 1;])
1662+
1663+
@test vcat(A, B) == [A; B]
1664+
@test vcat(A, B) == vcat([A, B])
1665+
1666+
@test [A D; B B C] == matrix(R, [1 2 1 2 3;
1667+
3 4 4 5 6;
1668+
1 2 1 2 0;
1669+
3 4 3 4 1;
1670+
0 1 0 1 0;
1671+
0 1 0 1 2;])
16441672
println("PASS")
16451673
end
16461674

0 commit comments

Comments
 (0)