Skip to content

Commit 18f2f9f

Browse files
dkarraschvtjnash
andauthored
Concatenation with UniformScaling and numbers (#41394)
Co-authored-by: Jameson Nash <[email protected]>
1 parent 4b1e6f3 commit 18f2f9f

File tree

4 files changed

+43
-16
lines changed

4 files changed

+43
-16
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1515
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
1616
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
1717
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
18-
using Base: IndexLinear, promote_op, promote_typeof,
19-
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing,
18+
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
19+
@propagate_inbounds, @pure, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
2020
splat
2121
using Base.Broadcast: Broadcasted, broadcasted
2222
import Libdl

stdlib/LinearAlgebra/src/uniformscaling.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ end
391391
# in A to matrices of type T and sizes given by n[k:end]. n is an array
392392
# so that the same promotion code can be used for hvcat. We pass the type T
393393
# so that we can re-use this code for sparse-matrix hcat etcetera.
394+
promote_to_arrays_(n::Int, ::Type, a::Number) = a
394395
promote_to_arrays_(n::Int, ::Type{Matrix}, J::UniformScaling{T}) where {T} = copyto!(Matrix{T}(undef, n,n), J)
395396
promote_to_arrays_(n::Int, ::Type, A::AbstractVecOrMat) = A
396397
promote_to_arrays(n,k, ::Type) = ()
@@ -401,11 +402,11 @@ promote_to_arrays(n,k, ::Type{T}, A, B, C) where {T} =
401402
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays_(n[k+2], T, C))
402403
promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
403404
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
404-
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling}}}) = Matrix
405+
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix
405406

406407
for (f,dim,name) in ((:hcat,1,"rows"), (:vcat,2,"cols"))
407408
@eval begin
408-
function $f(A::Union{AbstractVecOrMat,UniformScaling}...)
409+
function $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...)
409410
n = -1
410411
for a in A
411412
if !isa(a, UniformScaling)
@@ -418,13 +419,13 @@ for (f,dim,name) in ((:hcat,1,"rows"), (:vcat,2,"cols"))
418419
end
419420
end
420421
n == -1 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size")))
421-
return $f(promote_to_arrays(fill(n,length(A)),1, promote_to_array_type(A), A...)...)
422+
return cat(promote_to_arrays(fill(n, length(A)), 1, promote_to_array_type(A), A...)..., dims=Val(3-$dim))
422423
end
423424
end
424425
end
425426

426427

427-
function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling}...)
428+
function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...)
428429
require_one_based_indexing(A...)
429430
nr = length(rows)
430431
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
@@ -467,16 +468,27 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin
467468
j = 0
468469
for i = 1:nr
469470
if rows[i] > 0 && n[j+1] == -1 # this row consists entirely of UniformScalings
470-
nci = nc ÷ rows[i]
471-
nci * rows[i] != nc && throw(DimensionMismatch("indivisible UniformScaling sizes"))
471+
nci, r = divrem(nc, rows[i])
472+
r != 0 && throw(DimensionMismatch("indivisible UniformScaling sizes"))
472473
for k = 1:rows[i]
473474
n[j+k] = nci
474475
end
475476
end
476477
j += rows[i]
477478
end
478479
end
479-
return hvcat(rows, promote_to_arrays(n,1, promote_to_array_type(A), A...)...)
480+
Atyp = promote_to_array_type(A)
481+
Amat = promote_to_arrays(n, 1, Atyp, A...)
482+
# We have two methods for promote_to_array_type, one returning Matrix and
483+
# another one returning SparseMatrixCSC (in SparseArrays.jl). In the dense
484+
# case, we cannot call hvcat for the promoted UniformScalings because this
485+
# causes a stack overflow. In the sparse case, however, we cannot call
486+
# typed_hvcat because we need a sparse output.
487+
if Atyp == Matrix
488+
return typed_hvcat(promote_eltype(Amat...), rows, Amat...)
489+
else
490+
return hvcat(rows, Amat...)
491+
end
480492
end
481493

482494
## Matrix construction from UniformScaling

stdlib/LinearAlgebra/test/uniformscaling.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,19 @@ end
335335
B = T(rand(3,3))
336336
C = T(rand(0,3))
337337
D = T(rand(2,0))
338+
E = T(rand(1,3))
339+
F = T(rand(3,1))
340+
α = rand()
338341
@test (hcat(A, 2I))::T == hcat(A, Matrix(2I, 3, 3))
342+
@test (hcat(E, α))::T == hcat(E, [α])
343+
@test (hcat(E, α, 2I))::T == hcat(E, [α], fill(2, 1, 1))
339344
@test (vcat(A, 2I))::T == vcat(A, Matrix(2I, 4, 4))
345+
@test (vcat(F, α))::T == vcat(F, [α])
346+
@test (vcat(F, α, 2I))::T == vcat(F, [α], fill(2, 1, 1))
340347
@test (hcat(C, 2I))::T == C
348+
@test_throws DimensionMismatch hcat(C, α)
341349
@test (vcat(D, 2I))::T == D
350+
@test_throws DimensionMismatch vcat(D, α)
342351
@test (hcat(I, 3I, A, 2I))::T == hcat(Matrix(I, 3, 3), Matrix(3I, 3, 3), A, Matrix(2I, 3, 3))
343352
@test (vcat(I, 3I, A, 2I))::T == vcat(Matrix(I, 4, 4), Matrix(3I, 4, 4), A, Matrix(2I, 4, 4))
344353
@test (hvcat((2,1,2), B, 2I, I, 3I, 4I))::T ==
@@ -353,6 +362,9 @@ end
353362
hvcat((2,2,2), B, Matrix(2I, 3, 3), C, C, Matrix(3I, 3, 3), Matrix(4I, 3, 3))
354363
@test hvcat((3,2,1), C, C, I, B ,3I, 2I)::T ==
355364
hvcat((2,2,1), C, C, B, Matrix(3I,3,3), Matrix(2I,6,6))
365+
@test (hvcat((1,2), A, E, α))::T == hvcat((1,2), A, E, [α]) == hvcat((1,2), A, E, α*I)
366+
@test (hvcat((2,2), α, E, F, 3I))::T == hvcat((2,2), [α], E, F, Matrix(3I, 3, 3))
367+
@test (hvcat((2,2), 3I, F, E, α))::T == hvcat((2,2), Matrix(3I, 3, 3), F, E, [α])
356368
end
357369
end
358370

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,23 +1083,26 @@ const _Triangular_DenseArrays{T,A<:Matrix} = LinearAlgebra.AbstractTriangular{T,
10831083
const _Annotated_DenseArrays = Union{_Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
10841084
const _Annotated_Typed_DenseArrays{T} = Union{_Triangular_DenseArrays{T}, _Symmetric_DenseArrays{T}, _Hermitian_DenseArrays{T}}
10851085

1086-
const _SparseConcatGroup = Union{Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _SparseConcatArrays, _Annotated_SparseConcatArrays, _Annotated_DenseArrays}
1087-
const _DenseConcatGroup = Union{Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
1086+
const _SparseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _SparseConcatArrays, _Annotated_SparseConcatArrays, _Annotated_DenseArrays}
1087+
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
10881088
const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpose{T,Vector{T}}, Matrix{T}, _Annotated_Typed_DenseArrays{T}}
10891089

10901090
# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays
1091+
_makesparse(x::Number) = x
1092+
_makesparse(x::AbstractArray) = SparseMatrixCSC(issparse(x) ? x : sparse(x))
1093+
10911094
function Base._cat(dims, Xin::_SparseConcatGroup...)
1092-
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
1095+
X = map(_makesparse, Xin)
10931096
T = promote_eltype(Xin...)
10941097
Base.cat_t(T, X...; dims=dims)
10951098
end
10961099
function hcat(Xin::_SparseConcatGroup...)
1097-
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
1098-
hcat(X...)
1100+
X = map(_makesparse, Xin)
1101+
return cat(X..., dims=Val(2))
10991102
end
11001103
function vcat(Xin::_SparseConcatGroup...)
1101-
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
1102-
vcat(X...)
1104+
X = map(_makesparse, Xin)
1105+
return cat(X..., dims=Val(1))
11031106
end
11041107
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
11051108
vcat(_hvcat_rows(rows, X...)...)

0 commit comments

Comments
 (0)