diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 26e59fadd..bc910be4e 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -417,6 +417,21 @@ end function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T} LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) end + # Julia 1.12 introduced generic_mul! for scalar * array operations; the two methods below cover the array-times-scalar and scalar-times-array argument orders for AbstractGPUVecOrMat, throw DimensionMismatch when C and X differ in length, then update C in place with a single @. broadcast and return C. NOTE(review): beta multiplies the old contents of C even when it is zero, so a NaN already stored in C presumably survives a non-Bool zero beta; confirm this matches the CPU generic_mul! in LinearAlgebra. + function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, X::AbstractGPUVecOrMat, s::Number, alpha::Number, beta::Number) + if length(C) != length(X) + throw(DimensionMismatch(lazy"first array has length $(length(C)) which does not match the length of the second, $(length(X)).")) + end + @. C = X * s * alpha + C * beta + return C + end + function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, s::Number, X::AbstractGPUVecOrMat, alpha::Number, beta::Number) + if length(C) != length(X) + throw(DimensionMismatch(lazy"first array has length $(length(C)) which does not match the length of the second, $(length(X)).")) + end + @. 
C = s * X * alpha + C * beta + return C + end end function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} @@ -730,7 +745,7 @@ function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Numb @inbounds xi = x[i] @inbounds yi = y[i] @inbounds x[i] = s*yi + c *xi - @inbounds y[i] = c*yi - conj(s)*xi + @inbounds y[i] = c*yi - conj(s)*xi end rotate_kernel!(get_backend(x))(x, y, c, s; ndrange = size(x)) return x, y diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index d77a2bb50..7b79a62e2 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -256,17 +256,17 @@ B = Diagonal(b) A = Diagonal(a) mul!(C, A, B) - @test collect(C.diag) ≈ collect(A.diag) .* collect(B.diag) + @test collect(C.diag) ≈ collect(A.diag) .* collect(B.diag) a = AT(diagm(rand(elty, n))) b = AT(diagm(rand(elty, n))) C = Diagonal(d) mul!(C, a, b) - @test collect(C) ≈ Diagonal(collect(a) * collect(b)) + @test collect(C) ≈ Diagonal(collect(a) * collect(b)) a = transpose(AT(diagm(rand(elty, n)))) b = adjoint(AT(diagm(rand(elty, n)))) C = Diagonal(d) mul!(C, a, b) - @test collect(C) ≈ Diagonal(collect(a) * collect(b)) + @test collect(C) ≈ Diagonal(collect(a) * collect(b)) end end @@ -303,6 +303,42 @@ end end + @testset "mul! 
+ UniformScaling" begin + for elty in (Float32, ComplexF32) + n = 128 + s = rand(elty) + I_s = UniformScaling(s) + + # Vector case: a is the destination and b the source; b_copy is the copy of b used to build the expected value. + a = AT(rand(elty, n)) + b = AT(rand(elty, n)) + b_copy = copy(b) + + # mul!(a, I_s, b) with I_s = UniformScaling(s) should leave a equal to s times b. + mul!(a, I_s, b) + @test collect(a) ≈ s .* collect(b_copy) + + # mul!(a, b, s) into a freshly randomized destination should leave a equal to b times s. + a = AT(rand(elty, n)) + mul!(a, b, s) + @test collect(a) ≈ collect(b_copy) .* s + + # Matrix case: same two checks with n-by-n destination A and source B; B_copy is the reference copy. + A = AT(rand(elty, n, n)) + B = AT(rand(elty, n, n)) + B_copy = copy(B) + + # mul!(A, I_s, B) should leave A equal to s times B. + mul!(A, I_s, B) + @test collect(A) ≈ s .* collect(B_copy) + + # mul!(A, B, s) into a freshly randomized destination should leave A equal to B times s. + A = AT(rand(elty, n, n)) + mul!(A, B, s) + @test collect(A) ≈ collect(B_copy) .* s + end + end + @testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in eltypes @test compare(rmul!, AT, rand(T, a), Ref(rand(T))) @test compare(lmul!, AT, Ref(rand(T)), rand(T, b))