@@ -923,10 +923,20 @@ for (fname, elty) in
# In-place entry point restricted to dense CUDA arrays (`DenseCuArray{$elty, 3}`).
# All arguments are passed through unchanged to `_gemm_strided_batched!`, whose
# signature accepts any `AbstractArray{$elty, 3}` (e.g. a `PermutedDimsArray`).
#
# Arguments:
#   transA, transB -- `Char` flags forwarded as-is to the helper
#   alpha, beta    -- `Number` arguments forwarded as-is to the helper
#   A, B, C        -- 3-dimensional dense CUDA arrays with element type `$elty`
#
# Returns whatever `_gemm_strided_batched!` returns.
function gemm_strided_batched!(transA::Char,
                               transB::Char,
                               alpha::Number,
                               A::DenseCuArray{$elty, 3},
                               B::DenseCuArray{$elty, 3},
                               beta::Number,
                               C::DenseCuArray{$elty, 3})
    # The helper's name carries the trailing `!`; calling the un-suffixed
    # `_gemm_strided_batched` would not reach it.
    _gemm_strided_batched!(transA, transB, alpha, A, B, beta, C)
end
932+ function _gemm_strided_batched! (transA:: Char ,
933+ transB:: Char ,
934+ alpha:: Number ,
935+ A:: AbstractArray{$elty, 3} , # allows PermutedDimsArray
927936 B:: AbstractArray{$elty, 3} ,
928937 beta:: Number ,
929938 C:: AbstractArray{$elty, 3} )
939+
930940 m = size (A, transA == ' N' ? 1 : 2 )
931941 k = size (A, transA == ' N' ? 2 : 1 )
932942 n = size (B, transB == ' N' ? 2 : 1 )
@@ -952,15 +962,15 @@ for (fname, elty) in
# Allocating variant: builds the output array with `similar(B, ...)` and hands it to
# `gemm_strided_batched!` with `beta = zero($elty)`.
#
# Output dimensions:
#   1 -- `size(A, 1)` when `transA == 'N'`, otherwise `size(A, 2)`
#   2 -- `size(B, 2)` when `transB == 'N'`, otherwise `size(B, 1)`
#   3 -- the larger of `size(A, 3)` and `size(B, 3)`
#
# Returns whatever `gemm_strided_batched!` returns.
function gemm_strided_batched(transA::Char,
                              transB::Char,
                              alpha::Number,
                              A::DenseCuArray{$elty, 3},
                              B::DenseCuArray{$elty, 3})
    nrows = transA == 'N' ? size(A, 1) : size(A, 2)
    ncols = transB == 'N' ? size(B, 2) : size(B, 1)
    nbatch = max(size(A, 3), size(B, 3))
    C = similar(B, (nrows, ncols, nbatch))
    gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C)
end
# Convenience method without an explicit `alpha`: forwards to the five-argument
# `gemm_strided_batched` with `alpha = one($elty)`.
function gemm_strided_batched(transA::Char,
                              transB::Char,
                              A::DenseCuArray{$elty, 3},
                              B::DenseCuArray{$elty, 3})
    unit_alpha = one($elty)
    gemm_strided_batched(transA, transB, unit_alpha, A, B)
end
966976 end
0 commit comments