Commit f54187f

Authored by duqimeng, StareAtYou, metax666, jxwangmetax, and zhang-chenyi
[metax] updata_qr_kernel (#11)
* [metax] chang patch fix copy
* [metax] chang patch fix copy
* [Metax] update metax_gpu unit test
* [Metax] fix test CMakeList.txt
* [metax] change_cupti_and_fix_softmax
* [metax] change_patch
* [metax] change_patch
* [metax] updata_qr_kernel
* [metax] updata_qr_kernel

---------

Co-authored-by: Mingkun.Zhang <[email protected]>
Co-authored-by: metax666 <[email protected]>
Co-authored-by: jiaxinWang-metax <[email protected]>
Co-authored-by: MingkunZhang <[email protected]>
Co-authored-by: chezhang <[email protected]>
Co-authored-by: zhang-chenyi <[email protected]>
Co-authored-by: ZhouDuan <[email protected]>
1 parent 8938293 commit f54187f


backends/metax_gpu/kernels/metax_kernel/qr_kernel_register.cu

Lines changed: 98 additions & 109 deletions
@@ -22,9 +22,8 @@
 #include <algorithm>
 #include <vector>
 
-#include "kernels/impl/values_vectors_functor.h"
+#include "kernels/metax_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -39,7 +38,6 @@
 #include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 #include "paddle/phi/kernels/tril_triu_kernel.h"
-
 namespace phi {
 
 template <class T, class Context>
@@ -358,47 +356,47 @@ void QrKernel(const Context& dev_ctx,
 
 #ifdef PADDLE_WITH_HIP
 #define FUNC_WITH_TYPES(m) m(float, s) m(double, d)
-#define GEQRF_BATCH_INSTANCE(T, C)                                \
-  template <>                                                     \
-  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx,     \
-                                   int batch_size,                \
-                                   int m,                         \
-                                   int n,                         \
-                                   T* a,                          \
-                                   int lda,                       \
-                                   T* tau,                        \
-                                   int a_stride,                  \
-                                   int tau_stride) {              \
-    auto handle = dev_ctx.cusolver_dn_handle();                   \
-    for (int i = 0; i < batch_size; ++i) {                        \
-      T* a_working_ptr = &a[i * a_stride];                        \
-      T* tau_working_ptr = &tau[i * tau_stride];                  \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf(   \
-          handle, m, n, a_working_ptr, lda, tau_working_ptr));    \
-    }                                                             \
+#define GEQRF_BATCH_INSTANCE(T, C)                                           \
+  template <>                                                                \
+  void BatchedGeqrf<GPUContext, T>(const GPUContext& dev_ctx,                \
+                                   int batch_size,                           \
+                                   int m,                                    \
+                                   int n,                                    \
+                                   T* a,                                     \
+                                   int lda,                                  \
+                                   T* tau,                                   \
+                                   int a_stride,                             \
+                                   int tau_stride) {                         \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) {                                   \
+      T* a_working_ptr = &a[i * a_stride];                                   \
+      T* tau_working_ptr = &tau[i * tau_stride];                             \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##geqrf(              \
+          handle, m, n, a_working_ptr, lda, tau_working_ptr));               \
+    }                                                                        \
 }
 
 FUNC_WITH_TYPES(GEQRF_BATCH_INSTANCE);
 
-#define ORGQR_BATCH_INSTANCE(T, C)                                \
-  template <>                                                     \
-  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx,     \
-                                   int batch_size,                \
-                                   int m,                         \
-                                   int n,                         \
-                                   int k,                         \
-                                   T* a,                          \
-                                   int lda,                       \
-                                   T* tau,                        \
-                                   int a_stride,                  \
-                                   int tau_stride) {              \
-    auto handle = dev_ctx.cusolver_dn_handle();                   \
-    for (int i = 0; i < batch_size; ++i) {                        \
-      T* a_working_ptr = &a[i * a_stride];                        \
-      T* tau_working_ptr = &tau[i * tau_stride];                  \
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr(   \
-          handle, m, n, k, a_working_ptr, lda, tau_working_ptr)); \
-    }                                                             \
+#define ORGQR_BATCH_INSTANCE(T, C)                                           \
+  template <>                                                                \
+  void BatchedOrgqr<GPUContext, T>(const GPUContext& dev_ctx,                \
+                                   int batch_size,                           \
+                                   int m,                                    \
+                                   int n,                                    \
+                                   int k,                                    \
+                                   T* a,                                     \
+                                   int lda,                                  \
+                                   T* tau,                                   \
+                                   int a_stride,                             \
+                                   int tau_stride) {                         \
+    auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); \
+    for (int i = 0; i < batch_size; ++i) {                                   \
+      T* a_working_ptr = &a[i * a_stride];                                   \
+      T* tau_working_ptr = &tau[i * tau_stride];                             \
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_##C##orgqr(              \
+          handle, m, n, k, a_working_ptr, lda, tau_working_ptr));            \
+    }                                                                        \
 }
 
 FUNC_WITH_TYPES(ORGQR_BATCH_INSTANCE);
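
Both HIP macros expand to the same strided-batch idiom used by every cuSOLVER path below: matrix i of the packed batch starts at a + i * a_stride, its Householder scalars at tau + i * tau_stride, and the solver routine is invoked once per matrix, since the geqrf/orgqr entry points here factor a single matrix at a time. Distilled out of the macro machinery, the loop is simply (a sketch; solve_one stands in for the rocsolver_##C##geqrf/orgqr call):

// Sketch of the per-matrix batch loop that GEQRF_BATCH_INSTANCE and
// ORGQR_BATCH_INSTANCE expand to; T is float or double on the HIP path.
template <class T, class F>
void ForEachInBatch(int batch_size,
                    T* a, int a_stride,      // packed m-by-n matrices
                    T* tau, int tau_stride,  // packed tau vectors
                    F solve_one) {
  for (int i = 0; i < batch_size; ++i) {
    T* a_working_ptr = &a[i * a_stride];        // start of the i-th matrix
    T* tau_working_ptr = &tau[i * tau_stride];  // start of the i-th tau
    solve_one(a_working_ptr, tau_working_ptr);  // one factorization per matrix
  }
}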
@@ -421,7 +419,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
   const int64_t a_stride_64 = static_cast<int64_t>(a_stride);
   const int64_t tau_stride_64 = static_cast<int64_t>(tau_stride);
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
 
   size_t workspace_in_bytes_on_device = 0;
@@ -499,7 +496,6 @@ void BatchedGeqrf<GPUContext, float>(const GPUContext& dev_ctx,
   } else {
     int lwork = 0;
 
-    // auto handle = dev_ctx.cusolver_dn_handle();
     auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf_bufferSize(
         handle, m, n, a, lda, &lwork));
@@ -555,7 +551,6 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
                                       int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork));
@@ -599,35 +594,33 @@ void BatchedGeqrf<GPUContext, double>(const GPUContext& dev_ctx,
 }
 
 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuComplex*>(a), lda, &lwork));
 
   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);
 
   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);
 
   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf(
         handle,
@@ -657,35 +650,33 @@ void BatchedGeqrf<GPUContext, phi::dtype::complex<float>>(
 }
 
 template <>
-void BatchedGeqrf<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedGeqrf<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf_bufferSize(
       handle, m, n, reinterpret_cast<cuDoubleComplex*>(a), lda, &lwork));
 
   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);
 
   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);
 
   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute geqrf
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf(
         handle,
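
Every cuSOLVER branch in this file, including the two complex geqrf specializations above, follows the same four-step protocol: query the workspace size with the *_bufferSize variant, allocate lwork elements, run the factorization, then check the device-side info flag. Stripped of the phi DenseTensor allocator, the protocol against raw CUDA looks roughly like this (a self-contained single-precision sketch with error checking abbreviated; not code from this repository):

#include <cuda_runtime.h>
#include <cusolverDn.h>

// Sketch: QR factorization (geqrf) of one m x n float matrix d_A (device
// pointer, column-major, leading dimension lda), following the same
// bufferSize -> workspace -> factorize -> info protocol as the kernel above.
void GeqrfOnce(cusolverDnHandle_t handle, int m, int n, float* d_A, int lda,
               float* d_tau) {
  int lwork = 0;
  cusolverDnSgeqrf_bufferSize(handle, m, n, d_A, lda, &lwork);  // 1: query size

  float* d_work = nullptr;  // 2: allocate the workspace on the device
  cudaMalloc(&d_work, sizeof(float) * lwork);
  int* d_info = nullptr;    // device-side status flag
  cudaMalloc(&d_info, sizeof(int));

  // 3: factorize; R lands in the upper triangle of d_A, the Householder
  // vectors in the lower triangle, their scalars in d_tau.
  cusolverDnSgeqrf(handle, m, n, d_A, lda, d_tau, d_work, lwork, d_info);

  int info = 0;             // 4: copy the flag back and check it on the host
  cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
  // info == 0 means success; info == -i means the i-th argument was invalid.

  cudaFree(d_work);
  cudaFree(d_info);
}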
@@ -727,7 +718,6 @@ void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
                                      int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -784,7 +774,6 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
                                       int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
       handle, m, n, k, a, lda, tau, &lwork));
@@ -829,20 +818,18 @@ void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
 }
 
 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<float>* a,
-    int lda,
-    phi::dtype::complex<float>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex64>(const GPUContext& dev_ctx,
+                                              int batch_size,
+                                              int m,
+                                              int n,
+                                              int k,
+                                              phi::complex64* a,
+                                              int lda,
+                                              phi::complex64* tau,
+                                              int a_stride,
+                                              int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr_bufferSize(
       handle,
@@ -856,16 +843,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
 
   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<float>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<float>>(&workspace);
+  phi::complex64* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex64>(&workspace);
 
   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);
 
   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<float>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<float>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex64* a_working_ptr = &a[i * a_stride];
+    phi::complex64* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr(
         handle,
@@ -896,20 +883,18 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<float>>(
 }
 
 template <>
-void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
-    const GPUContext& dev_ctx,
-    int batch_size,
-    int m,
-    int n,
-    int k,
-    phi::dtype::complex<double>* a,
-    int lda,
-    phi::dtype::complex<double>* tau,
-    int a_stride,
-    int tau_stride) {
+void BatchedOrgqr<GPUContext, phi::complex128>(const GPUContext& dev_ctx,
+                                               int batch_size,
+                                               int m,
+                                               int n,
+                                               int k,
+                                               phi::complex128* a,
+                                               int lda,
+                                               phi::complex128* tau,
+                                               int a_stride,
+                                               int tau_stride) {
   int lwork = 0;
 
-  // auto handle = dev_ctx.cusolver_dn_handle();
   auto handle = GetCusolverDnHandle(dev_ctx.stream(), dev_ctx.GetPlace());
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr_bufferSize(
       handle,
@@ -923,16 +908,16 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
 
   DenseTensor workspace = DenseTensor();
   workspace.Resize(common::make_ddim({lwork}));
-  phi::dtype::complex<double>* workspace_ptr =
-      dev_ctx.template Alloc<phi::dtype::complex<double>>(&workspace);
+  phi::complex128* workspace_ptr =
+      dev_ctx.template Alloc<phi::complex128>(&workspace);
 
   DenseTensor info = DenseTensor();
   info.Resize(common::make_ddim({1}));
   int* info_d = dev_ctx.template Alloc<int>(&info);
 
   for (int i = 0; i < batch_size; ++i) {
-    phi::dtype::complex<double>* a_working_ptr = &a[i * a_stride];
-    phi::dtype::complex<double>* tau_working_ptr = &tau[i * tau_stride];
+    phi::complex128* a_working_ptr = &a[i * a_stride];
+    phi::complex128* tau_working_ptr = &tau[i * tau_stride];
     // compute orggr
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr(
         handle,
@@ -965,11 +950,15 @@ void BatchedOrgqr<GPUContext, phi::dtype::complex<double>>(
 
 }  // namespace phi
 
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(qr, GPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
+#else
 PD_REGISTER_PLUGIN_KERNEL(qr,
                           metax_gpu,
                           ALL_LAYOUT,
                           phi::QrKernel,
                           float,
                           double,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::complex64,
+                          phi::complex128) {}
+#endif
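
Two details of this last hunk deserve a note. The registration now branches: a HIP build registers qr as an in-tree GPU kernel for float/double only, matching the rocSOLVER macro instantiations above, which have no complex variants, while the normal metax_gpu build keeps PD_REGISTER_PLUGIN_KERNEL with all four dtypes. And the spelled-out phi::dtype::complex<...> names give way to phi's shorthand aliases, which, assuming the usual definitions, amount to:

// Assumed definitions of the aliases adopted throughout this diff (they ship
// with phi itself; reproduced here only to make the renaming explicit).
namespace phi {
using complex64 = dtype::complex<float>;    // layout-compatible with cuComplex
using complex128 = dtype::complex<double>;  // layout-compatible with cuDoubleComplex
}  // namespace phi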
