diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h
index 496b253dff5b3d..980b7cb35410b5 100644
--- a/paddle/fluid/platform/dynload/cublas.h
+++ b/paddle/fluid/platform/dynload/cublas.h
@@ -80,8 +80,14 @@ namespace dynload {
   __macro(cublasSgetriBatched);                \
   __macro(cublasDgetrfBatched);                \
   __macro(cublasDgetriBatched);                \
+  __macro(cublasCgetrfBatched);                \
+  __macro(cublasCgetriBatched);                \
+  __macro(cublasZgetrfBatched);                \
+  __macro(cublasZgetriBatched);                \
   __macro(cublasSmatinvBatched);               \
   __macro(cublasDmatinvBatched);               \
+  __macro(cublasCmatinvBatched);               \
+  __macro(cublasZmatinvBatched);               \
   __macro(cublasSgetrsBatched);                \
   __macro(cublasDgetrsBatched);
diff --git a/paddle/phi/backends/dynload/cublas.h b/paddle/phi/backends/dynload/cublas.h
index 308ae2accef146..48bf53ba2349d4 100644
--- a/paddle/phi/backends/dynload/cublas.h
+++ b/paddle/phi/backends/dynload/cublas.h
@@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
   __macro(cublasSgetriBatched);                \
   __macro(cublasDgetrfBatched);                \
   __macro(cublasDgetriBatched);                \
+  __macro(cublasCgetrfBatched);                \
+  __macro(cublasCgetriBatched);                \
+  __macro(cublasZgetrfBatched);                \
+  __macro(cublasZgetriBatched);                \
   __macro(cublasSmatinvBatched);               \
   __macro(cublasDmatinvBatched);               \
+  __macro(cublasCmatinvBatched);               \
+  __macro(cublasZmatinvBatched);               \
   __macro(cublasSgetrsBatched);                \
   __macro(cublasDgetrsBatched);
diff --git a/paddle/phi/kernels/cpu/determinant_grad_kernel.cc b/paddle/phi/kernels/cpu/determinant_grad_kernel.cc
index e57d7263f88bfc..0eb588c0dc4b4f 100644
--- a/paddle/phi/kernels/cpu/determinant_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/determinant_grad_kernel.cc
@@ -22,4 +22,6 @@ PD_REGISTER_KERNEL(determinant_grad,
                    ALL_LAYOUT,
                    phi::DeterminantGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/determinant_kernel.cc b/paddle/phi/kernels/cpu/determinant_kernel.cc
index 5810e88e92527f..fe212b848b66d0 100644
--- a/paddle/phi/kernels/cpu/determinant_kernel.cc
+++ b/paddle/phi/kernels/cpu/determinant_kernel.cc
@@ -17,5 +17,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
+PD_REGISTER_KERNEL(determinant,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DeterminantKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
index 96b2128eee16c5..75ab4e1023e8a9 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -657,6 +657,63 @@ struct CUBlas<phi::dtype::complex<float>> {
 #endif
   }
 
+  static void GETRF_BATCH(cublasHandle_t handle,
+                          int n,
+                          phi::dtype::complex<float> **A,
+                          int lda,
+                          int *ipiv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
+        handle,
+        n,
+        reinterpret_cast<cuFloatComplex **>(A),
+        lda,
+        ipiv,
+        info,
+        batch_size));
+  }
+
+  static void GETRI_BATCH(cublasHandle_t handle,
+                          int n,
+                          const phi::dtype::complex<float> **A,
+                          int lda,
+                          const int *ipiv,
+                          phi::dtype::complex<float> **A_inv,
+                          int lda_inv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuFloatComplex **>(A),
+        lda,
+        ipiv,
+        reinterpret_cast<cuFloatComplex **>(A_inv),
+        lda_inv,
+        info,
+        batch_size));
+  }
+
+  static void MATINV_BATCH(cublasHandle_t handle,
+                           int n,
+                           const phi::dtype::complex<float> **A,
+                           int lda,
+                           phi::dtype::complex<float> **A_inv,
+                           int lda_inv,
+                           int *info,
+                           int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuFloatComplex **>(A),
+        lda,
+        reinterpret_cast<cuFloatComplex **>(A_inv),
+        lda_inv,
+        info,
+        batch_size));
+  }
+
   static void TRSM_BATCH(cublasHandle_t handle,
                          cublasSideMode_t side,
                          cublasFillMode_t uplo,
@@ -836,6 +893,63 @@ struct CUBlas<phi::dtype::complex<double>> {
         ldb));
   }
 
+  static void GETRF_BATCH(cublasHandle_t handle,
+                          int n,
+                          phi::dtype::complex<double> **A,
+                          int lda,
+                          int *ipiv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
+        handle,
+        n,
+        reinterpret_cast<cuDoubleComplex **>(A),
+        lda,
+        ipiv,
+        info,
+        batch_size));
+  }
+
+  static void GETRI_BATCH(cublasHandle_t handle,
+                          int n,
+                          const phi::dtype::complex<double> **A,
+                          int lda,
+                          const int *ipiv,
+                          phi::dtype::complex<double> **A_inv,
+                          int lda_inv,
+                          int *info,
+                          int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuDoubleComplex **>(A),
+        lda,
+        ipiv,
+        reinterpret_cast<cuDoubleComplex **>(A_inv),
+        lda_inv,
+        info,
+        batch_size));
+  }
+
+  static void MATINV_BATCH(cublasHandle_t handle,
+                           int n,
+                           const phi::dtype::complex<double> **A,
+                           int lda,
+                           phi::dtype::complex<double> **A_inv,
+                           int lda_inv,
+                           int *info,
+                           int batch_size) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
+        handle,
+        n,
+        reinterpret_cast<const cuDoubleComplex **>(A),
+        lda,
+        reinterpret_cast<cuDoubleComplex **>(A_inv),
+        lda_inv,
+        info,
+        batch_size));
+  }
+
   static void TRSM_BATCH(cublasHandle_t handle,
                          cublasSideMode_t side,
                          cublasFillMode_t uplo,
diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cc b/paddle/phi/kernels/funcs/matrix_inverse.cc
index c316970e6a5600..fc0f1340cf8f98 100644
--- a/paddle/phi/kernels/funcs/matrix_inverse.cc
+++ b/paddle/phi/kernels/funcs/matrix_inverse.cc
@@ -23,11 +23,13 @@ template <typename Context, typename T>
 void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
                                                   const DenseTensor& a,
                                                   DenseTensor* a_inv) {
-  ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
+  MatrixInverseTrait<Context, T>::ComputeInverseEigen(dev_ctx, a, a_inv);
 }
 
 template class MatrixInverseFunctor<CPUContext, float>;
 template class MatrixInverseFunctor<CPUContext, double>;
+template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
+template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu
index c0ea7ad84c41b1..827f8b00ed2083 100644
--- a/paddle/phi/kernels/funcs/matrix_inverse.cu
+++ b/paddle/phi/kernels/funcs/matrix_inverse.cu
@@ -125,12 +125,14 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
         info[i]));
   }
 #else
-  ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
+  MatrixInverseTrait<Context, T>::ComputeInverseEigen(dev_ctx, a, a_inv);
 #endif
 }
 
 template class MatrixInverseFunctor<GPUContext, float>;
 template class MatrixInverseFunctor<GPUContext, double>;
+template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
+template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/matrix_inverse.h b/paddle/phi/kernels/funcs/matrix_inverse.h
index f0cd265a546481..472077d245487f 100644
--- a/paddle/phi/kernels/funcs/matrix_inverse.h
+++ b/paddle/phi/kernels/funcs/matrix_inverse.h
@@ -26,34 +26,108 @@ namespace phi {
 namespace funcs {
 
 template <typename Context, typename T>
-void ComputeInverseEigen(const Context& dev_ctx,
-                         const DenseTensor& a,
-                         DenseTensor* a_inv) {
-  using Matrix =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
-  using EigenMatrixMap = Eigen::Map<Matrix>;
-  using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
-  const auto& mat_dims = a.dims();
-  const int rank = mat_dims.size();
-  int n = mat_dims[rank - 1];
-  int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
-
-  const T* a_ptr = a.data<T>();
-  T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);
-
-  for (int i = 0; i < batch_size; ++i) {
-    ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
-    EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
-    Eigen::PartialPivLU<Matrix> lu;
-    lu.compute(mat);
-
-    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
-    PADDLE_ENFORCE_GT(min_abs_pivot,
-                      static_cast<T>(0),
-                      errors::InvalidArgument("Input is not invertible."));
-    mat_inv.noalias() = lu.inverse();
+struct MatrixInverseTrait {
+  static void ComputeInverseEigen(const Context& dev_ctx,
+                                  const DenseTensor& a,
+                                  DenseTensor* a_inv) {
+    using Matrix =
+        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+    using EigenMatrixMap = Eigen::Map<Matrix>;
+    using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
+    const auto& mat_dims = a.dims();
+    const int rank = mat_dims.size();
+    int n = mat_dims[rank - 1];
+    int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+
+    const T* a_ptr = a.data<T>();
+    T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);
+
+    for (int i = 0; i < batch_size; ++i) {
+      ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
+      EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
+      Eigen::PartialPivLU<Matrix> lu;
+      lu.compute(mat);
+
+      PADDLE_ENFORCE_NE(mat.determinant(),
+                        static_cast<T>(0),
+                        errors::InvalidArgument("Input is not invertible."));
+
+      mat_inv.noalias() = lu.inverse();
+    }
   }
-}
+};
+
+template <typename Context>
+struct MatrixInverseTrait<Context, phi::dtype::complex<float>> {
+  static void ComputeInverseEigen(const Context& dev_ctx,
+                                  const DenseTensor& a,
+                                  DenseTensor* a_inv) {
+    using Matrix = Eigen::Matrix<std::complex<float>,
+                                 Eigen::Dynamic,
+                                 Eigen::Dynamic,
+                                 Eigen::RowMajor>;
+    using EigenMatrixMap = Eigen::Map<Matrix>;
+    using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
+    const auto& mat_dims = a.dims();
+    const int rank = mat_dims.size();
+    int n = mat_dims[rank - 1];
+    int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+
+    const auto* a_ptr = reinterpret_cast<const std::complex<float>*>(
+        a.data<phi::dtype::complex<float>>());
+    auto* a_inv_ptr = reinterpret_cast<std::complex<float>*>(
+        dev_ctx.template Alloc<phi::dtype::complex<float>>(a_inv));
+
+    for (int i = 0; i < batch_size; ++i) {
+      ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
+      EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
+      Eigen::PartialPivLU<Matrix> lu;
+      lu.compute(mat);
+
+      PADDLE_ENFORCE_NE(mat.determinant(),
+                        static_cast<std::complex<float>>(0),
+                        errors::InvalidArgument("Input is not invertible."));
+
+      mat_inv.noalias() = lu.inverse();
+    }
+  }
+};
+
+template <typename Context>
+struct MatrixInverseTrait<Context, phi::dtype::complex<double>> {
+  static void ComputeInverseEigen(const Context& dev_ctx,
+                                  const DenseTensor& a,
+                                  DenseTensor* a_inv) {
+    using Matrix = Eigen::Matrix<std::complex<double>,
+                                 Eigen::Dynamic,
+                                 Eigen::Dynamic,
+                                 Eigen::RowMajor>;
+    using EigenMatrixMap = Eigen::Map<Matrix>;
+    using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
+    const auto& mat_dims = a.dims();
+    const int rank = mat_dims.size();
+    int n = mat_dims[rank - 1];
+    int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+
+    const auto* a_ptr = reinterpret_cast<const std::complex<double>*>(
+        a.data<phi::dtype::complex<double>>());
+    auto* a_inv_ptr = reinterpret_cast<std::complex<double>*>(
+        dev_ctx.template Alloc<phi::dtype::complex<double>>(a_inv));
+
+    for (int i = 0; i < batch_size; ++i) {
+      ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
+      EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
+      Eigen::PartialPivLU<Matrix> lu;
+      lu.compute(mat);
+
+      PADDLE_ENFORCE_NE(mat.determinant(),
+                        static_cast<std::complex<double>>(0),
+                        errors::InvalidArgument("Input is not invertible."));
+
+      mat_inv.noalias() = lu.inverse();
+    }
+  }
+};
 
 template <typename Context, typename T>
 class MatrixInverseFunctor {
diff --git a/paddle/phi/kernels/gpu/determinant_grad_kernel.cu b/paddle/phi/kernels/gpu/determinant_grad_kernel.cu
index f3187d5fefb519..26cb97f74866bc 100644
--- a/paddle/phi/kernels/gpu/determinant_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/determinant_grad_kernel.cu
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(determinant_grad,
                    phi::DeterminantGradKernel,
                    phi::dtype::float16,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/determinant_kernel.cu b/paddle/phi/kernels/gpu/determinant_kernel.cu
index 58e27e3ce4abda..0cbd70e0b1a0f1 100644
--- a/paddle/phi/kernels/gpu/determinant_kernel.cu
+++ b/paddle/phi/kernels/gpu/determinant_kernel.cu
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(determinant,
                    phi::DeterminantKernel,
                    phi::dtype::float16,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
index 8b135c4b520ae8..1baaf57e8d7ded 100644
--- a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/tensor_utils.h"
 
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/complex_kernel.h"
 #include "paddle/phi/kernels/determinant_grad_kernel.h"
 #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
diff --git a/paddle/phi/kernels/impl/determinant_kernel_impl.h b/paddle/phi/kernels/impl/determinant_kernel_impl.h
index 4a308a5798192d..86e76228fd29dd 100644
--- a/paddle/phi/kernels/impl/determinant_kernel_impl.h
+++ b/paddle/phi/kernels/impl/determinant_kernel_impl.h
@@ -51,6 +51,24 @@ class EigenMatrix<double> {
   using MatrixType = Eigen::MatrixXd;
 };
 
+template <>
+class EigenMatrix<phi::dtype::complex<float>> {
+ public:
+  using MatrixType = Eigen::Matrix<std::complex<float>,
+                                   Eigen::Dynamic,
+                                   Eigen::Dynamic,
+                                   Eigen::RowMajor>;
+};
+
+template <>
+class EigenMatrix<phi::dtype::complex<double>> {
+ public:
+  using MatrixType = Eigen::Matrix<std::complex<double>,
+                                   Eigen::Dynamic,
+                                   Eigen::Dynamic,
+                                   Eigen::RowMajor>;
+};
+
 inline int64_t GetBatchCount(const DDim dims) {
   int64_t batch_count = 1;
   auto dim_size = dims.size();
@@ -91,7 +109,7 @@ struct DeterminantFunctor {
       typename detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
       for (int64_t i = 0; i < rank; ++i) {
         for (int64_t j = 0; j < rank; ++j) {
-          matrix(i, j) = sub_vec[rank * i + j];
+          matrix(i, j) = sub_vec[i * rank + j];
         }
       }
       output_vec.push_back(
@@ -101,6 +119,68 @@ struct DeterminantFunctor {
   }
 };
 
+template <typename Context>
+struct DeterminantFunctor<phi::dtype::complex<float>, Context> {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& input,
+                  int64_t rank,
+                  int64_t batch_count,
+                  DenseTensor* output) {
+    std::vector<phi::dtype::complex<float>> input_vec;
+    std::vector<phi::dtype::complex<float>> output_vec;
+    phi::TensorToVector(input, dev_ctx, &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<phi::dtype::complex<float>> sub_vec(
+          begin_iter,
+          end_iter);  // get every square matrix data
+      typename detail::EigenMatrix<phi::dtype::complex<float>>::MatrixType
+          matrix(rank, rank);
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
+          matrix(i, j) = std::complex<float>(sub_vec[i * rank + j].real,
+                                             sub_vec[i * rank + j].imag);
+        }
+      }
+      output_vec.push_back(
+          static_cast<phi::dtype::complex<float>>(matrix.determinant()));
+    }
+    phi::TensorFromVector(output_vec, dev_ctx, output);
+  }
+};
+
+template <typename Context>
+struct DeterminantFunctor<phi::dtype::complex<double>, Context> {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& input,
+                  int64_t rank,
+                  int64_t batch_count,
+                  DenseTensor* output) {
+    std::vector<phi::dtype::complex<double>> input_vec;
+    std::vector<phi::dtype::complex<double>> output_vec;
+    phi::TensorToVector(input, dev_ctx, &input_vec);
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
+      auto begin_iter = input_vec.begin() + i * rank * rank;
+      auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
+      std::vector<phi::dtype::complex<double>> sub_vec(
+          begin_iter,
+          end_iter);  // get every square matrix data
+      typename detail::EigenMatrix<phi::dtype::complex<double>>::MatrixType
+          matrix(rank, rank);
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
+          matrix(i, j) = std::complex<double>(sub_vec[i * rank + j].real,
+                                              sub_vec[i * rank + j].imag);
+        }
+      }
+      output_vec.push_back(
+          static_cast<phi::dtype::complex<double>>(matrix.determinant()));
+    }
+    phi::TensorFromVector(output_vec, dev_ctx, output);
+  }
+};
+
 template <typename T, typename Context>
 void DeterminantKernel(const Context& dev_ctx,
                        const DenseTensor& x,
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 5ff36cdb754d53..18c8387d4497f1 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -2385,7 +2385,12 @@ def det(x, name=None):
     if in_dynamic_or_pir_mode():
         return _C_ops.det(x)
     else:
-        check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det')
+        check_dtype(
+            x.dtype,
+            'Input',
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'det',
+        )
 
         input_shape = list(x.shape)
         assert len(input_shape) >= 2, (
diff --git a/test/legacy_test/test_determinant_op.py b/test/legacy_test/test_determinant_op.py
index 2fe7217225f74b..c2d20f5817b3f4 100644
--- a/test/legacy_test/test_determinant_op.py
+++ b/test/legacy_test/test_determinant_op.py
@@ -79,6 +79,52 @@ def init_data(self):
         )
 
 
+class TestDeterminantOpCaseComplex64(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = (
+            np.random.uniform(-1, 1, (10, 10))
+            + 1j * np.random.uniform(-1, 1, (10, 10))
+        ).astype("complex64")
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True, check_prim=False)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input'],
+            ['Out'],
+            check_pir=True,
+            check_prim=False,
+            max_relative_error=2,
+        )
+
+
+class TestDeterminantOpCaseComplex128(TestDeterminantOp):
+    def init_data(self):
+        np.random.seed(0)
+        self.case = (
+            np.random.uniform(-1, 1, (10, 10))
+            + 1j * np.random.uniform(-1, 1, (10, 10))
+        ).astype("complex128")
+        self.inputs = {'Input': self.case}
+        self.target = np.linalg.det(self.case)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True, check_prim=False)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input'],
+            ['Out'],
+            check_pir=True,
+            check_prim=False,
+            max_relative_error=2,
+        )
+
+
 class TestDeterminantAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
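For reference, below is a minimal end-to-end sketch (not part of the diff) of the behavior this patch enables: `paddle.linalg.det` accepting complex64/complex128 inputs, validated against `np.linalg.det` the same way the new op tests above do. The tolerance value is illustrative, not taken from the test suite.

```python
# Sketch: complex determinant via the newly registered complex kernels.
import numpy as np
import paddle

np.random.seed(0)
case = (
    np.random.uniform(-1, 1, (10, 10))
    + 1j * np.random.uniform(-1, 1, (10, 10))
).astype("complex64")

x = paddle.to_tensor(case)      # dtype inferred as complex64
out = paddle.linalg.det(x)      # dispatches to the complex64 kernel

# Compare against the NumPy reference, as the new tests do.
np.testing.assert_allclose(out.numpy(), np.linalg.det(case), rtol=1e-4)
```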