From a554819db802f33bf2fc0df85baa36a36f0763fc Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 12 May 2023 18:06:34 +0200 Subject: [PATCH 1/7] Reduce number of `mul!` methods --- src/host/linalg.jl | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 5db045d91..4c3aada88 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -319,27 +319,13 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) - -# specificity hacks -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) +using LinearAlgebra: MulAddMul +function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) + transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint + transB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint + generic_matmatmul!(C, transA(A), transB(B), a, b) +end function generic_rmul!(X::AbstractArray, s::Number) gpu_call(X, s; name="rmul!") do ctx, X, s From 1185b64b2d50a0a6e8173bf64daaadf845574d87 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 12 May 2023 18:29:17 +0200 Subject: [PATCH 2/7] Update linalg.jl --- src/host/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 4c3aada88..df9fc534e 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -324,7 +324,7 @@ using LinearAlgebra: MulAddMul function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint transB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), transB(B), a, b) + generic_matmatmul!(C, transA(A), transB(B), A, B) end function generic_rmul!(X::AbstractArray, s::Number) From b57b795497a386352874d5b1edd96d4359e74aab Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 12 May 2023 18:30:34 +0200 Subject: [PATCH 3/7] Update linalg.jl --- src/host/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index df9fc534e..3cd508024 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -324,7 +324,7 @@ using LinearAlgebra: MulAddMul function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint transB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), transB(B), A, B) + generic_matmatmul!(C, transA(A), transB(B), _add.alpha, _add.beta) end function generic_rmul!(X::AbstractArray, s::Number) From 410fe477a03e0fe0748406e945d05b59862ddce3 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 12 May 2023 20:01:10 +0200 Subject: [PATCH 4/7] catch mv mul cases --- src/host/linalg.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 3cd508024..14204196e 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -321,6 +321,20 @@ end using LinearAlgebra: MulAddMul +function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) + transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint + generic_matmatmul!(C, transA(A), B, a, b) +end +function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) + transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint + generic_matmatmul!(C, transA(A), B, a, b) +end +# disambiguation +function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat} + transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint + generic_matmatmul!(C, transA(A), B, a, b) +end + function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint transB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint From 1feb1f7b33dbb9de7589bda89eecbbe99af426be Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 12 May 2023 20:17:16 +0200 Subject: [PATCH 5/7] Update linalg.jl --- src/host/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 14204196e..65ca0bac6 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -327,7 +327,7 @@ function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::Abstract end function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), B, a, b) + generic_matmatmul!(C, transA(A), B, _add.alpha, _add.beta) end # disambiguation function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat} From e50f24a6a5d820d83cf30693e8d6fba6489f64d8 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 31 May 2023 20:49:21 +0200 Subject: [PATCH 6/7] update --- src/host/linalg.jl | 70 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 65ca0bac6..0fa96bcab 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -1,5 +1,29 @@ # integration with LinearAlgebra stdlib +using LinearAlgebra: MulAddMul + +if isdefined(LinearAlgebra, :wrap) # i.e., VERSION >= v"1.10.0-DEV.1365" + using LinearAlgebra: wrap +else + function wrap(A::AbstractVecOrMat, tA::AbstractChar) + if tA == 'N' + return A + elseif tA == 'T' + return transpose(A) + elseif tA == 'C' + return adjoint(A) + elseif tA == 'H' + return Hermitian(A, :U) + elseif tA == 'h' + return Hermitian(A, :L) + elseif tA == 'S' + return Symmetric(A, :U) + else # tA == 's' + return Symmetric(A, :L) + end + end +end + ## transpose and adjoint function LinearAlgebra.transpose!(B::AbstractGPUVector, A::AbstractGPUMatrix) @@ -319,28 +343,46 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end -using LinearAlgebra: MulAddMul +function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) + generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta) +end -function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) - transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), B, a, b) +function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) + generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) end -function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) - transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), B, _add.alpha, _add.beta) + +if VERSION < v"1.10.0-DEV.1365" +# catch other functions that are called by LinearAlgebra's mul! +function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) + generic_matmatmul!(C, wrap(A, tA), B, a, b) end # disambiguation function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat} - transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), B, a, b) + generic_matmatmul!(C, wrap(A, tA), B, a, b) end -function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) - transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint - transB = tB == 'N' ? identity : tB == 'T' ? transpose : adjoint - generic_matmatmul!(C, transA(A), transB(B), _add.alpha, _add.beta) +LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) +# disambiguation +LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA, tB, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} = + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) + +function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul()) + if tA == 'T' + LinearAlgebra.generic_matmatmul!(C, 'T', 'N', A, A, _add) + else # tA == 'N' + LinearAlgebra.generic_matmatmul!(C, 'N', 'T', A, A, _add) + end end - +function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul()) + if tA == 'C' + LinearAlgebra.generic_matmatmul!(C, 'C', 'N', A, A, _add) + else # tA == 'N' + LinearAlgebra.generic_matmatmul!(C, 'N', 'C', A, A, _add) + end +end +end # VERSION + function generic_rmul!(X::AbstractArray, s::Number) gpu_call(X, s; name="rmul!") do ctx, X, s i = @linearidx X From 23382b2caca37f57f44fcd8a2c9da9ea876f0c44 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 31 May 2023 21:35:14 +0200 Subject: [PATCH 7/7] add type annotations --- src/host/linalg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 0fa96bcab..fc90d9a1c 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -361,10 +361,10 @@ function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::Abstr generic_matmatmul!(C, wrap(A, tA), B, a, b) end -LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) = +LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) # disambiguation -LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA, tB, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} = +LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())