Skip to content

Commit fec82b7

Browse files
dkarraschKristofferC
authored andcommitted
Complete size checks in BLAS.[sy/he]mm! (#45605)
(cherry picked from commit da13d78)
1 parent 5e9bb06 commit fec82b7

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
2-
32
"""
43
Interface to BLAS subroutines.
54
"""
@@ -1509,11 +1508,27 @@ for (mfname, elty) in ((:dsymm_,:Float64),
15091508
require_one_based_indexing(A, B, C)
15101509
m, n = size(C)
15111510
j = checksquare(A)
1512-
if j != (side == 'L' ? m : n)
1513-
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
1514-
end
1515-
if size(B,2) != n
1516-
throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1511+
M, N = size(B)
1512+
if side == 'L'
1513+
if j != m
1514+
throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m"))
1515+
end
1516+
if N != n
1517+
throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n"))
1518+
end
1519+
if j != M
1520+
throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M"))
1521+
end
1522+
else
1523+
if j != n
1524+
throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n"))
1525+
end
1526+
if N != j
1527+
throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j"))
1528+
end
1529+
if M != m
1530+
throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m"))
1531+
end
15171532
end
15181533
chkstride1(A)
15191534
chkstride1(B)
@@ -1582,11 +1597,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64),
15821597
require_one_based_indexing(A, B, C)
15831598
m, n = size(C)
15841599
j = checksquare(A)
1585-
if j != (side == 'L' ? m : n)
1586-
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
1587-
end
1588-
if size(B,2) != n
1589-
throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1600+
M, N = size(B)
1601+
if side == 'L'
1602+
if j != m
1603+
throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m"))
1604+
end
1605+
if N != n
1606+
throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n"))
1607+
end
1608+
if j != M
1609+
throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M"))
1610+
end
1611+
else
1612+
if j != n
1613+
throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n"))
1614+
end
1615+
if N != j
1616+
throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j"))
1617+
end
1618+
if M != m
1619+
throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m"))
1620+
end
15901621
end
15911622
chkstride1(A)
15921623
chkstride1(B)

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,19 @@ Random.seed!(100)
223223
@test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn)
224224
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn)
225225
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm)
226+
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn)
227+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn)
228+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm)
229+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn)
226230
if elty <: BlasComplex
227231
@test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn)
228232
@test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn)
229233
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn)
230234
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm)
235+
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn)
236+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn)
237+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm)
238+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn)
231239
end
232240
end
233241
end

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ end
340340
C = zeros(eltya,n,n)
341341
@test Hermitian(aherm) * a aherm * a
342342
@test a * Hermitian(aherm) a * aherm
343+
# rectangular multiplication
344+
@test [a; a] * Hermitian(aherm) [a; a] * aherm
345+
@test Hermitian(aherm) * [a a] aherm * [a a]
343346
@test Hermitian(aherm) * Hermitian(aherm) aherm*aherm
344347
@test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1)
345348
LinearAlgebra.mul!(C,a,Hermitian(aherm))
@@ -348,6 +351,9 @@ end
348351
@test Symmetric(asym) * Symmetric(asym) asym*asym
349352
@test Symmetric(asym) * a asym * a
350353
@test a * Symmetric(asym) a * asym
354+
# rectangular multiplication
355+
@test Symmetric(asym) * [a a] asym * [a a]
356+
@test [a; a] * Symmetric(asym) [a; a] * asym
351357
@test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1)
352358
LinearAlgebra.mul!(C,a,Symmetric(asym))
353359
@test C a*asym

0 commit comments

Comments
 (0)