@@ -22,9 +22,8 @@
 #include <algorithm>
 #include <vector>

-#include "kernels/impl/values_vectors_functor.h"
+#include "kernels/metax_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -39,7 +38,6 @@
 #include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 #include "paddle/phi/kernels/tril_triu_kernel.h"
-
 namespace phi {

 template <class T, class Context>
@@ -358,47 +356,47 @@ void QrKernel(const Context& dev_ctx,

 #ifdef PADDLE_WITH_HIP
 #define FUNC_WITH_TYPES(m) m(float, s) m(double, d)
-#define GEQRF_BATCH_INSTANCE(T, C) \
-  template <> \
-  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
-                                   int batch_size, \
-                                   int m, \
-                                   int n, \
-                                   T* a, \
-                                   int lda, \
-                                   T* tau, \
-                                   int a_stride, \
-                                   int tau_stride) { \
-    auto handle = dev_ctx.cusolver_dn_handle(); \
-    for (int i = 0; i < batch_size; ++i) { \
-      T* a_working_ptr = &a[i * a_stride]; \
-      T* tau_working_ptr = &tau[i * tau_stride]; \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
-          handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
-    } \
+#define GEQRF_BATCH_INSTANCE(T, C) \
+  template <> \
+  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
+                                   int batch_size, \
+                                   int m, \
+                                   int n, \
+                                   T* a, \
+                                   int lda, \
+                                   T* tau, \
+                                   int a_stride, \
+                                   int tau_stride) { \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) { \
+      T* a_working_ptr = &a[i * a_stride]; \
+      T* tau_working_ptr = &tau[i * tau_stride]; \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
+          handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
+    } \
   }

 FUNC_WITH_TYPES(GEQRF_BATCH_INSTANCE);

-#define ORGQR_BATCH_INSTANCE(T, C) \
-  template <> \
-  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
-                                   int batch_size, \
-                                   int m, \
-                                   int n, \
-                                   int k, \
-                                   T* a, \
-                                   int lda, \
-                                   T* tau, \
-                                   int a_stride, \
-                                   int tau_stride) { \
-    auto handle = dev_ctx.cusolver_dn_handle(); \
-    for (int i = 0; i < batch_size; ++i) { \
-      T* a_working_ptr = &a[i * a_stride]; \
-      T* tau_working_ptr = &tau[i * tau_stride]; \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
-          handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
-    } \
+#define ORGQR_BATCH_INSTANCE(T, C) \
+  template <> \
+  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
+                                   int batch_size, \
+                                   int m, \
+                                   int n, \
+                                   int k, \
+                                   T* a, \
+                                   int lda, \
+                                   T* tau, \
+                                   int a_stride, \
+                                   int tau_stride) { \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) { \
+      T* a_working_ptr = &a[i * a_stride]; \
+      T* tau_working_ptr = &tau[i * tau_stride]; \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
+          handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
+    } \
   }

 FUNC_WITH_TYPES(ORGQR_BATCH_INSTANCE);
@@ -421,7 +419,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
   const int64_t a_stride_64 = static_cast<int64_t>(a_stride);
   const int64_t tau_stride_64 = static_cast<int64_t>(tau_stride);

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());

   size_t workspace_in_bytes_on_device = 0;
@@ -499,7 +496,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
   } else {
     int lwork = 0;

-    // auto handle = dev_ctx.cusolver_dn_handle();
     auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf_bufferSize(
         handle, m, n, a, lda, &lwork));
@@ -555,7 +551,6 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
                                        int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
@@ -599,35 +594,33 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
 }

 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuComplex*>(a), lda, &lwork));

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf(
         handle,
@@ -657,35 +650,33 @@ void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
 }

 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuDoubleComplex*>(a), lda, &lwork));

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf(
         handle,
@@ -727,7 +718,6 @@ void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
                                       int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -784,7 +774,6 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
                                        int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -829,20 +818,18 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
 }

 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              int k,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr_bufferSize(
       handle,
@@ -856,16 +843,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr(
         handle,
@@ -896,20 +883,18 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
 }

 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               int k,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr_bufferSize(
       handle,
@@ -923,16 +908,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr(
         handle,
@@ -965,11 +950,15 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

 }  // namespace phi

+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(qr, GPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
+#else
 PD_REGISTER_PLUGIN_KERNEL(qr,
                           metax_gpu,
                           ALL_LAYOUT,
                           phi::QrKernel,
                           float,
                           double,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::complex64,
+                          phi::complex128) {}
+#endif