diff --git a/backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu
index 7b133371f4d..745069e2eda 100644
--- a/backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu
@@ -22,9 +22,8 @@
 #include
 #include

-#include "kernels/impl/values_vectors_functor.h"
+#include "kernels/metax_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -39,7 +38,6 @@
 #include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 #include "paddle/phi/kernels/tril_triu_kernel.h"
-
 namespace phi {

 template
@@ -358,47 +356,47 @@ void QrKernel(const Context& dev_ctx,
 #ifdef PADDLE_WITH_HIP
 #define FUNC_WITH_TYPES(m) m(float, s) m(double, d)

-#define GEQRF_BATCH_INSTANCE(T, C) \
-  template <> \
-  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
-                                   int batch_size, \
-                                   int m, \
-                                   int n, \
-                                   T* a, \
-                                   int lda, \
-                                   T* tau, \
-                                   int a_stride, \
-                                   int tau_stride) { \
-    auto handle = dev_ctx.cusolver_dn_handle(); \
-    for (int i = 0; i < batch_size; ++i) { \
-      T* a_working_ptr = &a[i * a_stride]; \
-      T* tau_working_ptr = &tau[i * tau_stride]; \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
-          handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
-    } \
+#define GEQRF_BATCH_INSTANCE(T, C) \
+  template <> \
+  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx, \
+                                   int batch_size, \
+                                   int m, \
+                                   int n, \
+                                   T* a, \
+                                   int lda, \
+                                   T* tau, \
+                                   int a_stride, \
+                                   int tau_stride) { \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) { \
+      T* a_working_ptr = &a[i * a_stride]; \
+      T* tau_working_ptr = &tau[i * tau_stride]; \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf( \
+          handle, m, n, a_working_ptr, lda, tau_working_ptr)); \
+    } \
   }

 FUNC_WITH_TYPES(GEQRF_BATCH_INSTANCE);

-#define ORGQR_BATCH_INSTANCE(T, C) \
-  template <> \
-  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
-                                   int batch_size, \
-                                   int m, \
-                                   int n, \
-                                   int k, \
-                                   T* a, \
-                                   int lda, \
-                                   T* tau, \
-                                   int a_stride, \
-                                   int tau_stride) { \
-    auto handle = dev_ctx.cusolver_dn_handle(); \
-    for (int i = 0; i < batch_size; ++i) { \
-      T* a_working_ptr = &a[i * a_stride]; \
-      T* tau_working_ptr = &tau[i * tau_stride]; \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
-          handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
-    } \
+#define ORGQR_BATCH_INSTANCE(T, C) \
+  template <> \
+  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx, \
+                                   int batch_size, \
+                                   int m, \
+                                   int n, \
+                                   int k, \
+                                   T* a, \
+                                   int lda, \
+                                   T* tau, \
+                                   int a_stride, \
+                                   int tau_stride) { \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) { \
+      T* a_working_ptr = &a[i * a_stride]; \
+      T* tau_working_ptr = &tau[i * tau_stride]; \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr( \
+          handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
+    } \
   }

 FUNC_WITH_TYPES(ORGQR_BATCH_INSTANCE);
@@ -421,7 +419,6 @@ void BatchedGeqrf(const GPUContext& dev_ctx,
   const int64_t a_stride_64 = static_cast<int64_t>(a_stride);
   const int64_t tau_stride_64 = static_cast<int64_t>(tau_stride);

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());

   size_t workspace_in_bytes_on_device = 0;
@@ -499,7 +496,6 @@ void BatchedGeqrf(const GPUContext& dev_ctx,
   } else {
     int lwork = 0;

-    // auto handle = dev_ctx.cusolver_dn_handle();
     auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf_bufferSize(
         handle, m, n, a, lda, &lwork));
@@ -555,7 +551,6 @@ void BatchedGeqrf(const GPUContext& dev_ctx,
                                       int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
@@ -599,35 +594,33 @@ void BatchedGeqrf(const GPUContext& dev_ctx,
 }

 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuComplex*>(a), lda, &lwork));

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf(
         handle,
@@ -657,35 +650,33 @@ void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
 }

 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuDoubleComplex*>(a), lda, &lwork));

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf(
         handle,
@@ -727,7 +718,6 @@ void BatchedOrgqr(const GPUContext& dev_ctx,
                                      int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -784,7 +774,6 @@ void BatchedOrgqr(const GPUContext& dev_ctx,
                                      int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -829,20 +818,18 @@ void BatchedOrgqr(const GPUContext& dev_ctx,
 }

 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              int k,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr_bufferSize(
       handle,
@@ -856,16 +843,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr(
         handle,
@@ -896,20 +883,18 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
 }

 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               int k,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;

-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr_bufferSize(
       handle,
@@ -923,16 +908,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);

   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);

   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr(
         handle,
@@ -965,11 +950,15 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(

 }  // namespace phi

+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(qr, GPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
+#else
 PD_REGISTER_PLUGIN_KERNEL(qr,
                           metax_gpu,
                           ALL_LAYOUT,
                           phi::QrKernel,
                           float,
                           double,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::complex64,
+                          phi::complex128) {}
+#endif
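Note: GetCusolverDnHandle comes from the newly included kernels/metax_context.h, whose implementation is not part of this diff. As a rough, hypothetical sketch of the pattern it presumably follows (a cuSOLVER dense handle created once per stream and bound to it; the function name, the per-stream cache, and the omission of the Place argument below are all assumptions):

// Hypothetical sketch only -- the real helper lives in kernels/metax_context.h
// and also takes the Place; names and caching strategy here are assumed.
#include <cuda_runtime.h>
#include <cusolverDn.h>

#include <map>
#include <mutex>

static cusolverDnHandle_t GetCusolverDnHandleSketch(cudaStream_t stream) {
  static std::map<cudaStream_t, cusolverDnHandle_t> cache;
  static std::mutex mu;
  std::lock_guard<std::mutex> guard(mu);
  auto it = cache.find(stream);
  if (it == cache.end()) {
    cusolverDnHandle_t handle = nullptr;
    cusolverDnCreate(&handle);            // create the dense-solver handle once
    cusolverDnSetStream(handle, stream);  // bind it to the kernel's stream
    it = cache.emplace(stream, handle).first;
  }
  return it->second;
}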