diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl index 401caec08d..f5ffaffe7c 100644 --- a/lib/cublas/linalg.jl +++ b/lib/cublas/linalg.jl @@ -1,8 +1,5 @@ # interfacing with LinearAlgebra standard library -cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1 : 2), size(M, t=='N' ? 2 : 1)) - - # # BLAS 1 @@ -74,40 +71,52 @@ end # GEMV -function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T}, - alpha::Number = true, beta::Number = false) where T<:CublasFloat - mA, nA = cublas_size(tA, A) - if nA != length(x) - throw(DimensionMismatch("second dimension of A, $nA, does not match length of x, $(length(x))")) +function gemv_dispatch!(Y::CuVector, A, B, alpha::Number=true, beta::Number=false) + mA, nA = size(A) + + if nA != length(B) + throw(DimensionMismatch("second dimension of A, $nA, does not match length of B, $(length(B))")) end - if mA != length(y) - throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))")) + + if mA != length(Y) + throw(DimensionMismatch("first dimension of A, $mA, does not match length of Y, $(length(Y))")) end + if mA == 0 - return y + return Y end + if nA == 0 - return rmul!(y, 0) + return rmul!(Y, 0) + end + + tA, dA = if A isa Transpose + 'T', parent(A) + elseif A isa Adjoint + 'C', parent(A) + else + 'N', A + end + + T = eltype(Y) + if T <: CublasFloat && A isa StridedCuArray{T} && B isa StridedCuArray{T} + gemv!(tA, alpha, dA, B, beta, Y) + else + gemm_dispatch!(Y, A, B, alpha, beta) end - gemv!(tA, alpha, A, x, beta, y) end -LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat = - gemv_wrapper!(Y, 'N', A, B, a, b) -LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat = - gemv_wrapper!(Y, 'T', A.parent, B, a, b) -LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasReal = - gemv_wrapper!(Y, 'T', A.parent, B, a, b) -LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex = - gemv_wrapper!(Y, 'C', A.parent, B, a, b) - -# ambiguity hacks: Base and GPUArrays has mul! with a::Real, b::Real -LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat = - gemv_wrapper!(Y, 'N', A, B, a, b) -LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat = - gemv_wrapper!(Y, 'T', A.parent, B, a, b) -LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasComplex = - gemv_wrapper!(Y, 'C', A.parent, B, a, b) +for NT in (Number, Real) + # NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods + @eval begin + LinearAlgebra.mul!(Y::CuVector, A::StridedCuMatrix, B::StridedCuVector, a::$NT, b::$NT) = + gemv_dispatch!(Y, A, B, a, b) + LinearAlgebra.mul!(Y::CuVector, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) = + gemv_dispatch!(Y, A, B, a, b) + LinearAlgebra.mul!(Y::CuVector, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) = + gemv_dispatch!(Y, A, B, a, b) + end +end # TRSV @@ -162,8 +171,13 @@ end # GEMM function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=false) - mA, nA = size(A) - mB, nB = size(B) + if ndims(A) > 2 + throw(ArgumentError("A has more than 2 dimensions")) + elseif ndims(B) > 2 + throw(ArgumentError("B has more than 2 dimensions")) + end + mA, nA = size(A,1), size(A,2) + mB, nB = size(B,1), size(B,2) if nA != mB throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) @@ -196,9 +210,11 @@ function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=fa 'N', B end - if gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing + T = eltype(C) + if dA isa DenseCuArray && dB isa DenseCuArray && + gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing gemmEx!(tA, tB, alpha, dA, dB, beta, C) - elseif eltype(A) === eltype(B) === eltype(C) && eltype(A) <: CublasFloat + elseif T <: CublasFloat && dA isa DenseCuArray{T} && dB isa DenseCuArray{T} gemm!(tA, tB, alpha, dA, dB, beta, C) else GPUArrays.generic_matmatmul!(C, A, B, alpha, beta) @@ -208,26 +224,26 @@ end for NT in (Number, Real) # NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods @eval begin - LinearAlgebra.mul!(C::CuMatrix, A::CuVecOrMat, B::CuVecOrMat, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::StridedCuVecOrMat, B::StridedCuVecOrMat, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) - LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) = + LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) = gemm_dispatch!(C, A, B, a, b) end end diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 5b4d33f0e8..b2682cae8a 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -281,10 +281,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64), @eval begin function gemv!(trans::Char, alpha::Number, - A::CuMatrix{$elty}, - X::CuVector{$elty}, + A::StridedCuMatrix{$elty}, + X::StridedCuVector{$elty}, beta::Number, - Y::CuVector{$elty}) + Y::DenseCuVector{$elty}) # handle trans m,n = size(A) # check dimensions @@ -296,10 +296,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64), $fname(handle(), trans, m, n, alpha, A, lda, X, incx, beta, Y, incy) Y end - function gemv(trans::Char, alpha::Number, A::CuMatrix{$elty}, X::CuVector{$elty}) + function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty}) gemv!(trans, alpha, A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2)))) end - function gemv(trans::Char, A::CuMatrix{$elty}, X::CuVector{$elty}) + function gemv(trans::Char, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty}) gemv!(trans, one($elty), A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2)))) end end @@ -708,10 +708,10 @@ for (fname, elty) in function gemm!(transA::Char, transB::Char, alpha::Number, - A::CuVecOrMat{$elty}, - B::CuVecOrMat{$elty}, + A::DenseCuVecOrMat{$elty}, + B::DenseCuVecOrMat{$elty}, beta::Number, - C::CuVecOrMat{$elty}) + C::DenseCuVecOrMat{$elty}) m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) @@ -727,16 +727,16 @@ for (fname, elty) in function gemm(transA::Char, transB::Char, alpha::Number, - A::CuMatrix{$elty}, - B::CuMatrix{$elty}) + A::DenseCuMatrix{$elty}, + B::DenseCuMatrix{$elty}) gemm!(transA, transB, alpha, A, B, zero($elty), similar(B, $elty, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1)))) end function gemm(transA::Char, transB::Char, - A::CuMatrix{$elty}, - B::CuMatrix{$elty}) + A::DenseCuMatrix{$elty}, + B::DenseCuMatrix{$elty}) gemm(transA, transB, one($elty), A, B) end end @@ -810,10 +810,10 @@ end function gemmEx!(transA::Char, transB::Char, @nospecialize(alpha::Number), - @nospecialize(A::CuVecOrMat), - @nospecialize(B::CuVecOrMat), + @nospecialize(A::DenseCuVecOrMat), + @nospecialize(B::DenseCuVecOrMat), @nospecialize(beta::Number), - @nospecialize(C::CuVecOrMat); + @nospecialize(C::DenseCuVecOrMat); algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) diff --git a/src/pointer.jl b/src/pointer.jl index 6ba077fdfa..ad9a2a8b19 100644 --- a/src/pointer.jl +++ b/src/pointer.jl @@ -70,6 +70,11 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{ Base._memory_offset(V.parent, map(first, V.indices)...) end +# from reshaped subarrays +function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{Base.RangeIndex,Base.ReshapedUnitRange}}}}) where {T,N,P} + return Base. unsafe_convert(CuPtr{T}, parent(V)) + + (Base.first_index(V)-1)*sizeof(T) +end ## limited pointer arithmetic & comparison diff --git a/test/cublas.jl b/test/cublas.jl index 9c8dbad540..fd0875cbb9 100644 --- a/test/cublas.jl +++ b/test/cublas.jl @@ -79,6 +79,7 @@ end dA = CuArray(A) @test_throws DimensionMismatch mul!(dy, dA, dx) end + @testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty) y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5) dy, dA, dx = CuArray(y), CuArray(A), CuArray(x) @@ -419,6 +420,13 @@ end end end end + + @testset "gemv! with strided inputs" begin # JuliaGPU/CUDA.jl#445 + testf(rand(16), rand(4)) do p, b + W = @view p[reshape(1:(16),4,4)] + W*b + end + end end ############################################################################################ @@ -1284,6 +1292,14 @@ end @test C ≈ Array(dC) rtol=rtol end end + + @testset "gemm! with strided inputs" begin # JuliaGPU/CUDA.jl#78 + inn = 784; out = 32 + testf(randn(784*100), rand(Float32, 784, 100)) do p, x + p[reshape(1:(out*inn),out,inn)] * x + @view(p[reshape(1:(out*inn),out,inn)]) * x + end + end end ############################################################################################