@@ -120,6 +120,11 @@ struct CUBlas<float> {
120120 PADDLE_ENFORCE_CUDA_SUCCESS (
121121 platform::dynload::cublasSgetrsBatched (args...));
122122 }
123+
124+ template <typename ... ARGS>
125+ static void TRSM_BATCH (ARGS... args) {
126+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasStrsmBatched (args...));
127+ }
123128};
124129
125130template <>
@@ -194,6 +199,11 @@ struct CUBlas<double> {
194199 PADDLE_ENFORCE_CUDA_SUCCESS (
195200 platform::dynload::cublasDgetrsBatched (args...));
196201 }
202+
203+ template <typename ... ARGS>
204+ static void TRSM_BATCH (ARGS... args) {
205+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasDtrsmBatched (args...));
206+ }
197207};
198208
199209template <>
@@ -339,6 +349,19 @@ struct CUBlas<platform::complex<float>> {
339349 reinterpret_cast <cuFloatComplex *>(C), ldc));
340350 }
341351
352+ static void TRSM (cublasHandle_t handle, cublasSideMode_t side,
353+ cublasFillMode_t uplo, cublasOperation_t transa,
354+ cublasDiagType_t diag, int m, int n,
355+ const paddle::platform::complex <float > *alpha,
356+ const paddle::platform::complex <float > *A, int lda,
357+ paddle::platform::complex <float > *B, int ldb) {
358+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasCtrsm (
359+ handle, side, uplo, transa, diag, m, n,
360+ reinterpret_cast <const cuFloatComplex *>(alpha),
361+ reinterpret_cast <const cuFloatComplex *>(A), lda,
362+ reinterpret_cast <cuFloatComplex *>(B), ldb));
363+ }
364+
342365 // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
343366 // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
344367 template <typename ... ARGS>
@@ -370,6 +393,20 @@ struct CUBlas<platform::complex<float>> {
370393 " cublasGemmEx is not supported on cuda <= 7.5" ));
371394#endif
372395 }
396+
397+ static void TRSM_BATCH (cublasHandle_t handle, cublasSideMode_t side,
398+ cublasFillMode_t uplo, cublasOperation_t transa,
399+ cublasDiagType_t diag, int m, int n,
400+ const paddle::platform::complex <float > *alpha,
401+ const paddle::platform::complex <float > **A, int lda,
402+ paddle::platform::complex <float > **B, int ldb,
403+ int batch_size) {
404+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasCtrsmBatched (
405+ handle, side, uplo, transa, diag, m, n,
406+ reinterpret_cast <const cuFloatComplex *>(alpha),
407+ reinterpret_cast <const cuFloatComplex **>(A), lda,
408+ reinterpret_cast <cuFloatComplex **>(B), ldb, batch_size));
409+ }
373410};
374411
375412template <>
@@ -440,6 +477,33 @@ struct CUBlas<platform::complex<double>> {
440477 reinterpret_cast <cuDoubleComplex *>(C), ldc));
441478 }
442479
480+ static void TRSM (cublasHandle_t handle, cublasSideMode_t side,
481+ cublasFillMode_t uplo, cublasOperation_t transa,
482+ cublasDiagType_t diag, int m, int n,
483+ const paddle::platform::complex <double > *alpha,
484+ const paddle::platform::complex <double > *A, int lda,
485+ paddle::platform::complex <double > *B, int ldb) {
486+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasZtrsm (
487+ handle, side, uplo, transa, diag, m, n,
488+ reinterpret_cast <const cuDoubleComplex *>(alpha),
489+ reinterpret_cast <const cuDoubleComplex *>(A), lda,
490+ reinterpret_cast <cuDoubleComplex *>(B), ldb));
491+ }
492+
493+ static void TRSM_BATCH (cublasHandle_t handle, cublasSideMode_t side,
494+ cublasFillMode_t uplo, cublasOperation_t transa,
495+ cublasDiagType_t diag, int m, int n,
496+ const paddle::platform::complex <double > *alpha,
497+ const paddle::platform::complex <double > **A, int lda,
498+ paddle::platform::complex <double > **B, int ldb,
499+ int batch_size) {
500+ PADDLE_ENFORCE_CUDA_SUCCESS (platform::dynload::cublasZtrsmBatched (
501+ handle, side, uplo, transa, diag, m, n,
502+ reinterpret_cast <const cuDoubleComplex *>(alpha),
503+ reinterpret_cast <const cuDoubleComplex **>(A), lda,
504+ reinterpret_cast <cuDoubleComplex **>(B), ldb, batch_size));
505+ }
506+
443507 // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
444508 // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
445509 template <typename ... ARGS>
@@ -897,6 +961,30 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
897961 });
898962}
899963
964+ template <>
965+ template <typename T>
966+ void Blas<platform::CUDADeviceContext>::BatchedTRSM (
967+ CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
968+ int M, int N, T alpha, const T **A, int lda, T **B, int ldb,
969+ int batch_size) const {
970+ // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
971+ // where ' stands for transpose
972+ cublasSideMode_t cuSide =
973+ (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
974+ cublasFillMode_t cuUplo =
975+ (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
976+ // use CUBLAS_OP_C (conjugate transpose) for complex
977+ cublasOperation_t cuTransA =
978+ (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
979+ cublasDiagType_t cuDiag =
980+ (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
981+
982+ context_.CublasCall ([&](cublasHandle_t handle) {
983+ CUBlas<T>::TRSM_BATCH (handle, cuSide, cuUplo, cuTransA, cuDiag, N, M,
984+ &alpha, A, lda, B, ldb, batch_size);
985+ });
986+ }
987+
900988} // namespace math
901989} // namespace operators
902990} // namespace paddle
0 commit comments