diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 9fad7bbb8..8f871868b 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -58,6 +58,13 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w copyto!(A, Transpose(Array(parent(B)))) end +## trace + +function LinearAlgebra.tr(A::AnyGPUMatrix) + LinearAlgebra.checksquare(A) + sum(diag(A)) +end + ## copy upper triangle to lower and vice versa function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index bdcc6cd50..f9ac5d924 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -14,6 +14,10 @@ @test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128)) end + @testset "tr" begin + @test compare(tr, AT, rand(Float32, 32, 32)) + end + @testset "permutedims" begin @test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3)) @test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))