Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 58 additions & 42 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# interfacing with LinearAlgebra standard library

cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1 : 2), size(M, t=='N' ? 2 : 1))



#
# BLAS 1
Expand Down Expand Up @@ -74,40 +71,52 @@ end

# GEMV

function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
alpha::Number = true, beta::Number = false) where T<:CublasFloat
mA, nA = cublas_size(tA, A)
if nA != length(x)
throw(DimensionMismatch("second dimension of A, $nA, does not match length of x, $(length(x))"))
function gemv_dispatch!(Y::CuVector, A, B, alpha::Number=true, beta::Number=false)
mA, nA = size(A)

if nA != length(B)
throw(DimensionMismatch("second dimension of A, $nA, does not match length of B, $(length(B))"))
end
if mA != length(y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))

if mA != length(Y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of Y, $(length(Y))"))
end

if mA == 0
return y
return Y
end

if nA == 0
return rmul!(y, 0)
return rmul!(Y, 0)
end

tA, dA = if A isa Transpose
'T', parent(A)
elseif A isa Adjoint
'C', parent(A)
else
'N', A
end

T = eltype(Y)
if T <: CublasFloat && A isa StridedCuArray{T} && B isa StridedCuArray{T}
gemv!(tA, alpha, dA, B, beta, Y)
else
gemm_dispatch!(Y, A, B, alpha, beta)
end
gemv!(tA, alpha, A, x, beta, y)
end

LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasReal =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Number, b::Number) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, a, b)

# ambiguity hacks: Base and GPUArrays has mul! with a::Real, b::Real
LinearAlgebra.mul!(Y::CuVector{T}, A::CuMatrix{T}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat =
gemv_wrapper!(Y, 'N', A, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Transpose{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasFloat =
gemv_wrapper!(Y, 'T', A.parent, B, a, b)
LinearAlgebra.mul!(Y::CuVector{T}, A::Adjoint{<:Any, <:CuVecOrMat{T}}, B::CuVector{T}, a::Real, b::Real) where T<:CublasComplex =
gemv_wrapper!(Y, 'C', A.parent, B, a, b)
for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
LinearAlgebra.mul!(Y::CuVector, A::StridedCuMatrix, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
LinearAlgebra.mul!(Y::CuVector, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
LinearAlgebra.mul!(Y::CuVector, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuVector, a::$NT, b::$NT) =
gemv_dispatch!(Y, A, B, a, b)
end
end

# TRSV

Expand Down Expand Up @@ -162,8 +171,13 @@ end
# GEMM

function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=false)
mA, nA = size(A)
mB, nB = size(B)
if ndims(A) > 2
throw(ArgumentError("A has more than 2 dimensions"))
elseif ndims(B) > 2
throw(ArgumentError("B has more than 2 dimensions"))
end
mA, nA = size(A,1), size(A,2)
mB, nB = size(B,1), size(B,2)

if nA != mB
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
Expand Down Expand Up @@ -196,9 +210,11 @@ function gemm_dispatch!(C::CuVecOrMat, A, B, alpha::Number=true, beta::Number=fa
'N', B
end

if gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing
T = eltype(C)
if dA isa DenseCuArray && dB isa DenseCuArray &&
gemmExComputeType(eltype(A), eltype(B), eltype(C), mA, nA, nB) !== nothing
gemmEx!(tA, tB, alpha, dA, dB, beta, C)
elseif eltype(A) === eltype(B) === eltype(C) && eltype(A) <: CublasFloat
elseif T <: CublasFloat && dA isa DenseCuArray{T} && dB isa DenseCuArray{T}
gemm!(tA, tB, alpha, dA, dB, beta, C)
else
GPUArrays.generic_matmatmul!(C, A, B, alpha, beta)
Expand All @@ -208,26 +224,26 @@ end
for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
LinearAlgebra.mul!(C::CuMatrix, A::CuVecOrMat, B::CuVecOrMat, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuVecOrMat, B::StridedCuVecOrMat, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::CuMatrix, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::StridedCuMatrix, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::CuMatrix, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::StridedCuMatrix, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)

LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:CuVecOrMat}, B::Adjoint{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Transpose{<:Any, <:StridedCuVecOrMat}, B::Adjoint{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:CuVecOrMat}, B::Transpose{<:Any, <:CuVecOrMat}, a::$NT, b::$NT) =
LinearAlgebra.mul!(C::CuMatrix, A::Adjoint{<:Any, <:StridedCuVecOrMat}, B::Transpose{<:Any, <:StridedCuVecOrMat}, a::$NT, b::$NT) =
gemm_dispatch!(C, A, B, a, b)
end
end
Expand Down
30 changes: 15 additions & 15 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64),
@eval begin
function gemv!(trans::Char,
alpha::Number,
A::CuMatrix{$elty},
X::CuVector{$elty},
A::StridedCuMatrix{$elty},
X::StridedCuVector{$elty},
beta::Number,
Y::CuVector{$elty})
Y::DenseCuVector{$elty})
# handle trans
m,n = size(A)
# check dimensions
Expand All @@ -296,10 +296,10 @@ for (fname, elty) in ((:cublasDgemv_v2,:Float64),
$fname(handle(), trans, m, n, alpha, A, lda, X, incx, beta, Y, incy)
Y
end
function gemv(trans::Char, alpha::Number, A::CuMatrix{$elty}, X::CuVector{$elty})
function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty})
gemv!(trans, alpha, A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::Char, A::CuMatrix{$elty}, X::CuVector{$elty})
function gemv(trans::Char, A::StridedCuMatrix{$elty}, X::StridedCuVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
Expand Down Expand Up @@ -708,10 +708,10 @@ for (fname, elty) in
function gemm!(transA::Char,
transB::Char,
alpha::Number,
A::CuVecOrMat{$elty},
B::CuVecOrMat{$elty},
A::DenseCuVecOrMat{$elty},
B::DenseCuVecOrMat{$elty},
beta::Number,
C::CuVecOrMat{$elty})
C::DenseCuVecOrMat{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
Expand All @@ -727,16 +727,16 @@ for (fname, elty) in
function gemm(transA::Char,
transB::Char,
alpha::Number,
A::CuMatrix{$elty},
B::CuMatrix{$elty})
A::DenseCuMatrix{$elty},
B::DenseCuMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
similar(B, $elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
function gemm(transA::Char,
transB::Char,
A::CuMatrix{$elty},
B::CuMatrix{$elty})
A::DenseCuMatrix{$elty},
B::DenseCuMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
end
Expand Down Expand Up @@ -810,10 +810,10 @@ end

function gemmEx!(transA::Char, transB::Char,
@nospecialize(alpha::Number),
@nospecialize(A::CuVecOrMat),
@nospecialize(B::CuVecOrMat),
@nospecialize(A::DenseCuVecOrMat),
@nospecialize(B::DenseCuVecOrMat),
@nospecialize(beta::Number),
@nospecialize(C::CuVecOrMat);
@nospecialize(C::DenseCuVecOrMat);
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
Expand Down
5 changes: 5 additions & 0 deletions src/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
Base._memory_offset(V.parent, map(first, V.indices)...)
end

# from reshaped subarrays
function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{Base.RangeIndex,Base.ReshapedUnitRange}}}}) where {T,N,P}
return Base. unsafe_convert(CuPtr{T}, parent(V)) +
(Base.first_index(V)-1)*sizeof(T)
end

## limited pointer arithmetic & comparison

Expand Down
16 changes: 16 additions & 0 deletions test/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ end
dA = CuArray(A)
@test_throws DimensionMismatch mul!(dy, dA, dx)
end

@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (identity, transpose, adjoint), Ts in (Int, elty)
y, A, x = rand(elty, 5), rand(elty, 5, 5), rand(elty, 5)
dy, dA, dx = CuArray(y), CuArray(A), CuArray(x)
Expand Down Expand Up @@ -419,6 +420,13 @@ end
end
end
end

@testset "gemv! with strided inputs" begin # JuliaGPU/CUDA.jl#445
testf(rand(16), rand(4)) do p, b
W = @view p[reshape(1:(16),4,4)]
W*b
end
end
end

############################################################################################
Expand Down Expand Up @@ -1284,6 +1292,14 @@ end
@test C ≈ Array(dC) rtol=rtol
end
end

@testset "gemm! with strided inputs" begin # JuliaGPU/CUDA.jl#78
inn = 784; out = 32
testf(randn(784*100), rand(Float32, 784, 100)) do p, x
p[reshape(1:(out*inn),out,inn)] * x
@view(p[reshape(1:(out*inn),out,inn)]) * x
end
end
end

############################################################################################
Expand Down