Skip to content

Commit 37d2f39

Browse files
committed
Add tests with rectangular matrices
1 parent c4f94c1 commit 37d2f39

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

src/host/linalg.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,7 @@ if isdefined(LinearAlgebra, :copytrito!)
112112
m,n = size(A)
113113
m1,n1 = size(B)
114114
if uplo == 'U'
115-
if n < m
116-
(m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)"))
117-
else
118-
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
119-
end
115+
LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (n < m ? n : m, n))
120116
@kernel function U_kernel!(_A, _B)
121117
I = @index(Global, Cartesian)
122118
i, j = Tuple(I)
@@ -126,11 +122,7 @@ if isdefined(LinearAlgebra, :copytrito!)
126122
end
127123
U_kernel!(get_backend(B))(A, B; ndrange = size(A))
128124
else # uplo == 'L'
129-
if m < n
130-
(m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)"))
131-
else
132-
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
133-
end
125+
LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (m, m < n ? m : n))
134126
@kernel function L_kernel!(_A, _B)
135127
I = @index(Global, Cartesian)
136128
i, j = Tuple(I)

test/testsuite/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@
8585
B = zeros(T,n,n)
8686
@test compare(copytrito!, AT, B, A, uplo)
8787
end
88+
@testset for T in eltypes, uplo in ('L', 'U')
89+
n = 16
90+
m = 32
91+
A = uplo == 'U' ? rand(T,m,n) : rand(T,n,m)
92+
B = zeros(T,n,n)
93+
@test compare(copytrito!, AT, B, A, uplo)
94+
end
8895
end
8996
end
9097

0 commit comments

Comments
 (0)