Skip to content

Commit 0692e58

Browse files
author
Michael Abbott
committed
work via _gemm_strided_batched
1 parent 03744a6 commit 0692e58

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

lib/cublas/wrappers.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -923,10 +923,20 @@ for (fname, elty) in
923923
function gemm_strided_batched!(transA::Char,
924924
transB::Char,
925925
alpha::Number,
926-
A::AbstractArray{$elty, 3},
926+
A::DenseCuArray{$elty, 3},
927+
B::DenseCuArray{$elty, 3},
928+
beta::Number,
929+
C::DenseCuArray{$elty, 3})
930+
_gemm_strided_batched!(transA, transB, alpha, A, B, beta, C)
931+
end
932+
function _gemm_strided_batched!(transA::Char,
933+
transB::Char,
934+
alpha::Number,
935+
A::AbstractArray{$elty, 3}, # allows PermutedDimsArray
927936
B::AbstractArray{$elty, 3},
928937
beta::Number,
929938
C::AbstractArray{$elty, 3})
939+
930940
m = size(A, transA == 'N' ? 1 : 2)
931941
k = size(A, transA == 'N' ? 2 : 1)
932942
n = size(B, transB == 'N' ? 2 : 1)
@@ -952,15 +962,15 @@ for (fname, elty) in
952962
function gemm_strided_batched(transA::Char,
953963
transB::Char,
954964
alpha::Number,
955-
A::AbstractArray{$elty, 3},
956-
B::AbstractArray{$elty, 3})
965+
A::DenseCuArray{$elty, 3},
966+
B::DenseCuArray{$elty, 3})
957967
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
958968
gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
959969
end
960970
function gemm_strided_batched(transA::Char,
961971
transB::Char,
962-
A::AbstractArray{$elty, 3},
963-
B::AbstractArray{$elty, 3})
972+
A::DenseCuArray{$elty, 3},
973+
B::DenseCuArray{$elty, 3})
964974
gemm_strided_batched(transA, transB, one($elty), A, B)
965975
end
966976
end

src/nnlib.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,5 @@ end
2323

2424

2525
# Batched matrix multiplication
26-
# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191
27-
28-
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
29-
CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
26+
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
27+
CUBLAS._gemm_strided_batched!(transA, transB, α, A, B, β, C)

0 commit comments

Comments
 (0)