Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w
copyto!(A, Transpose(Array(parent(B))))
end

## trace

function LinearAlgebra.tr(A::AnyGPUMatrix)
LinearAlgebra.checksquare(A)
sum(diag(A))
end

## copy upper triangle to lower and vice versa

function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)
Expand Down
3 changes: 3 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
@test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))
end

@testset "tr" begin
@test compare(tr, AT, rand(Float32, 32, 32))
end
@testset "permutedims" begin
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))
Expand Down
Loading