diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 79384a7d4..b4acc8e71 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -283,7 +283,8 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, m′, n′ = size(B, 1), size(B, 2) n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d")) (m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) - B .= A .* transpose(dd) + ddT = transpose(dd) + @. B = A * ddT B end @@ -299,7 +300,8 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, m′, n′ = size(B, 1), size(B, 2) n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d")) (m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) - B .= α * A .* transpose(dd) + β * B + ddT = transpose(dd) + @. B = α * A * ddT + β * B B end