From c4f94c1ad962636adaebbb92e9f68e5b1b1f5c82 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 29 May 2024 20:53:03 -0400 Subject: [PATCH 1/3] Accomodate for rectangular matrices in copytrito! --- src/host/linalg.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 7548b4f15..79384a7d4 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -111,8 +111,12 @@ if isdefined(LinearAlgebra, :copytrito!) LinearAlgebra.BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function U_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) @@ -122,6 +126,11 @@ if isdefined(LinearAlgebra, :copytrito!) end U_kernel!(get_backend(B))(A, B; ndrange = size(A)) else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function L_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) From 37d2f39838e77cfcc83a11a61927cd9c6b1b0ca3 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 29 Oct 2024 11:43:56 -0500 Subject: [PATCH 2/3] Add tests with rectangular matrices --- src/host/linalg.jl | 12 ++---------- test/testsuite/linalg.jl | 7 +++++++ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 79384a7d4..34745843f 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -112,11 +112,7 @@ if isdefined(LinearAlgebra, :copytrito!) m,n = size(A) m1,n1 = size(B) if uplo == 'U' - if n < m - (m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)")) - else - (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) - end + LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (n < m ? n : m, n)) @kernel function U_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) @@ -126,11 +122,7 @@ if isdefined(LinearAlgebra, :copytrito!) end U_kernel!(get_backend(B))(A, B; ndrange = size(A)) else # uplo == 'L' - if m < n - (m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)")) - else - (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) - end + LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (m, m < n ? m : n)) @kernel function L_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 7c03e69f0..8de318549 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -85,6 +85,13 @@ B = zeros(T,n,n) @test compare(copytrito!, AT, B, A, uplo) end + @testset for T in eltypes, uplo in ('L', 'U') + n = 16 + m = 32 + A = uplo == 'U' ? rand(T,m,n) : rand(T,n,m) + B = zeros(T,n,n) + @test compare(copytrito!, AT, B, A, uplo) + end end end From e322a317580ea467d0c5a4d4f75791f434092eda Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 29 Oct 2024 14:40:48 -0500 Subject: [PATCH 3/3] Don't use lacpy_size_check --- src/host/linalg.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 34745843f..79384a7d4 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -112,7 +112,11 @@ if isdefined(LinearAlgebra, :copytrito!) m,n = size(A) m1,n1 = size(B) if uplo == 'U' - LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (n < m ? n : m, n)) + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function U_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) @@ -122,7 +126,11 @@ if isdefined(LinearAlgebra, :copytrito!) end U_kernel!(get_backend(B))(A, B; ndrange = size(A)) else # uplo == 'L' - LinearAlgebra.LAPACK.lacpy_size_check((m1, n1), (m, m < n ? m : n)) + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function L_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I)