diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 97dad15c2..2282528e7 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -244,6 +244,20 @@ function Base.:\(D::Diagonal{<:Any, <:AbstractGPUArray}, B::AbstractGPUVecOrMat) end end +function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, + A::Diagonal{<:Any, <:AbstractGPUArray}, + B::Diagonal{<:Any, <:AbstractGPUArray}) + dc = C.diag + da = A.diag + db = B.diag + d = length(dc) + length(da) == d || throw(DimensionMismatch("right hand side has $(length(da)) rows but output is $d by $d")) + length(db) == d || throw(DimensionMismatch("left hand side has $(length(db)) rows but output is $d by $d")) + @. dc = da * db + + return C +end + function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, A::AbstractGPUVecOrMat) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index c17868f81..5770c85f0 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -243,6 +243,13 @@ mul!(X, B, D, α, β) mul!(Y, collect(B), Diagonal(collect(d)), α, β) @test collect(X) ≈ Y + a = AT(rand(elty, n)) + b = AT(rand(elty, n)) + C = Diagonal(d) + B = Diagonal(b) + A = Diagonal(a) + mul!(C, A, B) + @test collect(C.diag) ≈ collect(A.diag) .* collect(B.diag) end end