diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 4933839c6..97dad15c2 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -107,7 +107,7 @@ end ## copy a triangular part of a matrix to another matrix if isdefined(LinearAlgebra, :copytrito!) - function LinearAlgebra.copytrito!(B::AbstractGPUMatrix, A::AbstractGPUMatrix, uplo::AbstractChar) + function LinearAlgebra.copytrito!(B::AbstractGPUMatrix{T}, A::AbstractGPUMatrix{T}, uplo::AbstractChar) where {T} LinearAlgebra.BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) @@ -376,6 +376,13 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) end end +@static if VERSION ≥ v"1.12.0-rc" + # we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method + using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag + function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T} + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) + end +end function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} if size(A,2) != size(B,1)