From 489e0990df5fc4a02cc38fdac318854d78a72652 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 17:14:22 -0400 Subject: [PATCH 1/6] Add `tr` function for `AbstractGPUMatrix` --- src/host/linalg.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 9fad7bbb8..b026dda32 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -58,6 +58,11 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w copyto!(A, Transpose(Array(parent(B)))) end +function LinearAlgebra.tr(A::AbstractGPUMatrix) + checksquare(A) + sum(diag(A)) +end + ## copy upper triangle to lower and vice versa function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false) From 5ec920570ab5eae3b5e774ae23d41da65372ecca Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 17:17:52 -0400 Subject: [PATCH 2/6] Add test for 'tr' function in linalg.jl Added tests for the 'tr' function in the linalg test suite. --- test/testsuite/linalg.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index bdcc6cd50..87983fe85 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -14,6 +14,10 @@ @test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128)) end + @testset "tr" begin + @test compare(tr, AT, rand(Float32, 32, 32)) + end + @testset "permutedims" begin @test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3)) @test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6)) From 46103c1c68b1b9b5cb0955531a2091cb1de2bbe9 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 18:02:23 -0400 Subject: [PATCH 3/6] Fix call to LinearAlgebra.checksquare --- src/host/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index b026dda32..9693f03aa 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -59,7 +59,7 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w end function LinearAlgebra.tr(A::AbstractGPUMatrix) - checksquare(A) + LinearAlgebra.checksquare(A) sum(diag(A)) end From c0351a6f7dbee1e1c679f1bfe7882467049f5ec7 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 18:04:06 -0400 Subject: [PATCH 4/6] Generalize `tr` to AnyGPUMatrix --- src/host/linalg.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 9693f03aa..8f871868b 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -58,7 +58,9 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w copyto!(A, Transpose(Array(parent(B)))) end -function LinearAlgebra.tr(A::AbstractGPUMatrix) +## trace + +function LinearAlgebra.tr(A::AnyGPUMatrix) LinearAlgebra.checksquare(A) sum(diag(A)) end From 6ee51c97d88eed43c6c3a5ada676171c76dce07c Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 18:15:06 -0400 Subject: [PATCH 5/6] Format --- test/testsuite/linalg.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 87983fe85..7e6765755 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -17,7 +17,6 @@ @testset "tr" begin @test compare(tr, AT, rand(Float32, 32, 32)) end - @testset "permutedims" begin @test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3)) @test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6)) From 64b163b795f409d01c5ca573beb651058edc6e8f Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 9 Oct 2025 18:15:25 -0400 Subject: [PATCH 6/6] Format --- test/testsuite/linalg.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 7e6765755..f9ac5d924 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -17,6 +17,7 @@ @testset "tr" begin @test compare(tr, AT, rand(Float32, 32, 32)) end + @testset "permutedims" begin @test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3)) @test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))