Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ namespace dynload {
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/determinant_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ PD_REGISTER_KERNEL(determinant_grad,
ALL_LAYOUT,
phi::DeterminantGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/determinant_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"

PD_REGISTER_KERNEL(
determinant, CPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
PD_REGISTER_KERNEL(determinant,
CPU,
ALL_LAYOUT,
phi::DeterminantKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
114 changes: 114 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,63 @@ struct CUBlas<phi::dtype::complex<float>> {
#endif
}

// Batched in-place LU factorization for complex64 matrices: thin wrapper
// over cublasCgetrfBatched. phi::dtype::complex<float> is layout-compatible
// with cuComplex, so the pointer array is reinterpreted rather than copied.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<float> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  auto **a_array = reinterpret_cast<cuComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
      handle, n, a_array, lda, ipiv, info, batch_size));
}

// Batched inversion from precomputed LU factors: thin wrapper over
// cublasCgetriBatched for complex64. Input factors in A, result in A_inv.
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<float> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<float> **A_inv,
                        int lda_inv,
                        int *info,
                        int batch_size) {
  const auto **lu_array = reinterpret_cast<const cuComplex **>(A);
  auto **inv_array = reinterpret_cast<cuComplex **>(A_inv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
      handle, n, lu_array, lda, ipiv, inv_array, lda_inv, info, batch_size));
}

// Batched direct matrix inversion for complex64: thin wrapper over
// cublasCmatinvBatched (valid for n <= 32 per cuBLAS docs — caller's
// responsibility to respect that limit).
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<float> **A,
                         int lda,
                         phi::dtype::complex<float> **A_inv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  const auto **src_array = reinterpret_cast<const cuComplex **>(A);
  auto **dst_array = reinterpret_cast<cuComplex **>(A_inv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
      handle, n, src_array, lda, dst_array, lda_inv, info, batch_size));
}

static void TRSM_BATCH(cublasHandle_t handle,
cublasSideMode_t side,
cublasFillMode_t uplo,
Expand Down Expand Up @@ -836,6 +893,63 @@ struct CUBlas<phi::dtype::complex<double>> {
ldb));
}

// Batched in-place LU factorization for complex128 matrices: thin wrapper
// over cublasZgetrfBatched. phi::dtype::complex<double> is layout-compatible
// with cuDoubleComplex, so the pointer array is reinterpreted, not copied.
static void GETRF_BATCH(cublasHandle_t handle,
                        int n,
                        phi::dtype::complex<double> **A,
                        int lda,
                        int *ipiv,
                        int *info,
                        int batch_size) {
  auto **a_array = reinterpret_cast<cuDoubleComplex **>(A);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
      handle, n, a_array, lda, ipiv, info, batch_size));
}

// Batched inversion from precomputed LU factors: thin wrapper over
// cublasZgetriBatched for complex128. Input factors in A, result in A_inv.
static void GETRI_BATCH(cublasHandle_t handle,
                        int n,
                        const phi::dtype::complex<double> **A,
                        int lda,
                        const int *ipiv,
                        phi::dtype::complex<double> **A_inv,
                        int lda_inv,
                        int *info,
                        int batch_size) {
  const auto **lu_array = reinterpret_cast<const cuDoubleComplex **>(A);
  auto **inv_array = reinterpret_cast<cuDoubleComplex **>(A_inv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
      handle, n, lu_array, lda, ipiv, inv_array, lda_inv, info, batch_size));
}

// Batched direct matrix inversion for complex128: thin wrapper over
// cublasZmatinvBatched (valid for n <= 32 per cuBLAS docs — caller's
// responsibility to respect that limit).
static void MATINV_BATCH(cublasHandle_t handle,
                         int n,
                         const phi::dtype::complex<double> **A,
                         int lda,
                         phi::dtype::complex<double> **A_inv,
                         int lda_inv,
                         int *info,
                         int batch_size) {
  const auto **src_array = reinterpret_cast<const cuDoubleComplex **>(A);
  auto **dst_array = reinterpret_cast<cuDoubleComplex **>(A_inv);
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
      handle, n, src_array, lda, dst_array, lda_inv, info, batch_size));
}

static void TRSM_BATCH(cublasHandle_t handle,
cublasSideMode_t side,
cublasFillMode_t uplo,
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/funcs/matrix_inverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ template <typename Context, typename T>
void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
MatrixInverseTrait<Context, T>::ComputeInverseEigen(dev_ctx, a, a_inv);
}

template class MatrixInverseFunctor<CPUContext, float>;
template class MatrixInverseFunctor<CPUContext, double>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/funcs/matrix_inverse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
info[i]));
}
#else
ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
MatrixInverseTrait<Context, T>::ComputeInverseEigen(dev_ctx, a, a_inv);
#endif
}

template class MatrixInverseFunctor<GPUContext, float>;
template class MatrixInverseFunctor<GPUContext, double>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
128 changes: 101 additions & 27 deletions paddle/phi/kernels/funcs/matrix_inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,108 @@ namespace phi {
namespace funcs {

template <typename Context, typename T>
void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;

const T* a_ptr = a.data<T>();
T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(min_abs_pivot,
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
struct MatrixInverseTrait {
static void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;

const T* a_ptr = a.data<T>();
T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

PADDLE_ENFORCE_NE(mat.determinant(),
static_cast<T>(0),
errors::InvalidArgument("Input is not invertible."));

mat_inv.noalias() = lu.inverse();
}
}
}
};

template <typename Context>
struct MatrixInverseTrait<Context, phi::dtype::complex<float>> {
static void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix = Eigen::Matrix<std::complex<float>,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;

const auto* a_ptr = reinterpret_cast<const std::complex<float>*>(
a.data<phi::dtype::complex<float>>());
auto* a_inv_ptr = reinterpret_cast<std::complex<float>*>(
dev_ctx.template Alloc<phi::dtype::complex<float>>(a_inv));

for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

PADDLE_ENFORCE_NE(mat.determinant(),
static_cast<std::complex<float>>(0),
errors::InvalidArgument("Input is not invertible."));

mat_inv.noalias() = lu.inverse();
}
}
};

template <typename Context>
struct MatrixInverseTrait<Context, phi::dtype::complex<double>> {
static void ComputeInverseEigen(const Context& dev_ctx,
const DenseTensor& a,
DenseTensor* a_inv) {
using Matrix = Eigen::Matrix<std::complex<double>,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;

const auto* a_ptr = reinterpret_cast<const std::complex<double>*>(
a.data<phi::dtype::complex<double>>());
auto* a_inv_ptr = reinterpret_cast<std::complex<double>*>(
dev_ctx.template Alloc<phi::dtype::complex<double>>(a_inv));

for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);

PADDLE_ENFORCE_NE(mat.determinant(),
static_cast<std::complex<double>>(0),
errors::InvalidArgument("Input is not invertible."));

mat_inv.noalias() = lu.inverse();
}
}
};

template <typename Context, typename T>
class MatrixInverseFunctor {
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/determinant_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(determinant_grad,
phi::DeterminantGradKernel,
phi::dtype::float16,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/determinant_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(determinant,
phi::DeterminantKernel,
phi::dtype::float16,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/determinant_grad_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
Expand Down
Loading