From 410fb02a99a257ea83d112dc56a4b7af827b439f Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Thu, 16 Sep 2021 13:47:55 +0000 Subject: [PATCH 1/3] Optimization method implementation --- paddle/fluid/operators/eigh_op.cc | 13 +- paddle/fluid/operators/eigh_op.cu | 32 +--- paddle/fluid/operators/eigh_op.h | 33 ++-- .../operators/math/eigen_values_vectors.h | 165 +++++++----------- paddle/fluid/operators/svd_helper.h | 79 ++++----- 5 files changed, 129 insertions(+), 193 deletions(-) diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc index b3056bd43ba53d..5577dfb8f889bb 100644 --- a/paddle/fluid/operators/eigh_op.cc +++ b/paddle/fluid/operators/eigh_op.cc @@ -47,12 +47,9 @@ class EighOp : public framework::OperatorWithKernel { input_dim[rank - 2], input_dim[rank - 1])); std::vector values_dim; - if (rank > 2) { - for (auto i = 0; i < rank - 1; i++) { - values_dim.emplace_back(input_dim[i]); - } - } else { - values_dim = {input_dim[1]}; + + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); } ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim)); @@ -99,9 +96,9 @@ class EighGradOp : public framework::OperatorWithKernel { "EighGrad"); OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EighGrad"); - OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")), + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")), "Input", "Eigenvalues@GRAD", "EighGrad"); - OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")), + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")), "Input", "Eigenvectors@GRAD", "EighGrad"); auto dims = ctx->GetInputDim("Eigenvectors"); auto x_grad_name = framework::GradVarName("X"); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu index cfc9eba4509596..61d2b66ea536d6 100644 --- a/paddle/fluid/operators/eigh_op.cu +++ b/paddle/fluid/operators/eigh_op.cu @@ -14,34 +14,14 @@ limitations under the License. */ #include "paddle/fluid/operators/eigh_op.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class EighGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto input_var = ctx.Input("X"); - auto output_w_var = ctx.Output("Eigenvalues"); - auto output_v_var = ctx.Output("Eigenvectors"); - std::string lower = ctx.Attr("UPLO"); - bool is_lower = (lower == "L"); - math::MatrixEighFunctor functor; - functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; - REGISTER_OP_CUDA_KERNEL( - eigh, ops::EighGPUKernel, ops::EighGPUKernel, - ops::EighGPUKernel>, - ops::EighGPUKernel>); + eigh, ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, + ops::EighKernel>); REGISTER_OP_CUDA_KERNEL( eigh_grad, diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h index 0af38d44e54570..085e7531dd5232 100644 --- a/paddle/fluid/operators/eigh_op.h +++ b/paddle/fluid/operators/eigh_op.h @@ -22,24 +22,17 @@ namespace operators { using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; -template -using EigenVector = framework::EigenVector; - template class EighKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto input_var = ctx.Input("X"); - auto output_w_var = ctx.Output("Eigenvalues"); - auto output_v_var = ctx.Output("Eigenvectors"); + auto input = ctx.Input("X"); + auto output_w = ctx.Output("Eigenvalues"); + auto output_v = ctx.Output("Eigenvectors"); std::string lower = ctx.Attr("UPLO"); bool is_lower = (lower == "L"); - math::MatrixEighFunctorCPU functor; - functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + math::MatrixEighFunctor functor; + functor(ctx, *input, output_w, output_v, is_lower, true); } }; @@ -49,30 +42,30 @@ class EighGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& x_grad = *ctx.Output(framework::GradVarName("X")); x_grad.mutable_data(ctx.GetPlace()); - auto& output_w_var = *ctx.Input("Eigenvalues"); - auto& output_v_var = *ctx.Input("Eigenvectors"); + auto& output_w = *ctx.Input("Eigenvalues"); + auto& output_v = *ctx.Input("Eigenvectors"); auto& output_w_grad = *ctx.Input(framework::GradVarName("Eigenvalues")); auto& output_v_grad = *ctx.Input(framework::GradVarName("Eigenvectors")); - auto& dims = output_v_var.dims(); + auto& dims = output_v.dims(); const int m = dims[dims.size() - 1]; auto dito = math::DeviceIndependenceTensorOperations( ctx); - auto tV = dito.Transpose(dito.Conj(output_v_var)); - auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2), - dito.Unsqueeze(output_w_var, -1)); + auto tV = dito.Transpose(dito.Conj(output_v)); + auto W = dito.template Sub(dito.Unsqueeze(output_w, -2), + dito.Unsqueeze(output_w, -1)); Tensor result = dito.Matmul(tV, output_v_grad); result.mutable_data(dims, ctx.GetPlace()); std::vector out_shape = framework::vectorize(dims); auto constant = dito.Fill(out_shape, 0.5); result = dito.Sub(result, dito.Conj(dito.Transpose(result))); result = dito.Mul(result, constant); - result = dito.Div_(result, W); + result = dito.Div(result, W); result = dito.DiagFill(m, m, m, 0, output_w_grad, result); - x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV)); + x_grad = dito.Matmul(output_v, dito.Matmul(result, tV)); } }; diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 4e2d180e336281..70a41f67d5ba83 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -26,87 +26,16 @@ namespace paddle { namespace operators { namespace math { -template -using EigenTensor = framework::EigenTensor; - template -using InputMatrixMap = Eigen::Map< - const Eigen::Matrix>; +using EigenMatrix = + Eigen::Matrix; -template -using OutputMatrixMap = Eigen::Map< - Eigen::Matrix>; - -template -inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data, - ValueType *eigenvalues_data, - ValueType *eigenvectors_data, - int batches, int rows, int cols, - bool has_vectors) { - int stride = rows * cols; - for (int i = 0; i < batches; i++) { - auto m = InputMatrixMap(x_data + i * stride, rows, cols); - auto eigenvalues = - OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); - auto eigenvectors = - OutputMatrixMap(eigenvectors_data + i * stride, rows, cols); - - Eigen::SelfAdjointEigenSolver> - eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors - : Eigen::EigenvaluesOnly); - PADDLE_ENFORCE_EQ( - eigen_solver.info(), Eigen::Success, - platform::errors::InvalidArgument( - "Self Adjoint Eigen decomposition is not successful. " - "The %d-th input matrice might not be not be positive definite.", - i)); - - eigenvalues = eigen_solver.eigenvalues().transpose(); - if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors().transpose(); - } - } -} +template +using InputMatrixMap = Eigen::Map>; -template -inline void ComputeComplexEigenvaluesAndVectors(T *x_data, - ValueType *eigenvalues_data, - T *eigenvectors_data, - int batches, int rows, int cols, - bool has_vectors) { - using Complex = std::complex; - Complex *input = reinterpret_cast(x_data); - Complex *eigenvectors_data_ = reinterpret_cast(eigenvectors_data); - - int stride = rows * cols; - for (int i = 0; i < batches; i++) { - auto m = InputMatrixMap(input + i * stride, rows, cols); - auto eigenvalues = - OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); - auto eigenvectors = - OutputMatrixMap(eigenvectors_data_ + i * stride, rows, cols); - - Eigen::SelfAdjointEigenSolver< - Eigen::Matrix> - eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors - : Eigen::EigenvaluesOnly); - PADDLE_ENFORCE_EQ( - eigen_solver.info(), Eigen::Success, - platform::errors::InvalidArgument( - "Self Adjoint Eigen decomposition is not successful. " - "The %d-th input matrice might not be not be positive definite.", - i)); - - eigenvalues = eigen_solver.eigenvalues().transpose(); - if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors().transpose(); - } - } -} +template +using OutputMatrixMap = Eigen::Map>; inline int64_t GetBatchSize(framework::DDim dims) { int64_t batch_size = 1; @@ -117,11 +46,18 @@ inline int64_t GetBatchSize(framework::DDim dims) { return batch_size; } +template +struct MatrixEighFunctor { + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors); +}; + // Calculates the eigenvalues ​​and eigenvectors of Hermitian or real // symmetric matrices, and uses the variable has_vectors to // control whether to return the eigenvectors. -template -struct MatrixEighFunctorCPU { +template +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, @@ -134,7 +70,8 @@ struct MatrixEighFunctorCPU { for (int64_t i = 0; i < dim_size - 2; i++) { batch_size *= dims[i]; } - auto dito = DeviceIndependenceTensorOperations(ctx); + auto dito = + DeviceIndependenceTensorOperations(ctx); Tensor input_tensor; TensorCopy(input, ctx.GetPlace(), &input_tensor); if (!is_lower) { @@ -145,31 +82,61 @@ struct MatrixEighFunctorCPU { auto *value_data = eigen_values->mutable_data(output_value_dim, ctx.GetPlace()); - if (framework::IsComplexType(input_tensor.type())) { - auto *x_data = input_tensor.data(); - auto *vector_data = eigen_vectors->mutable_data(dims, ctx.GetPlace()); - ComputeComplexEigenvaluesAndVectors( - x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); - } else { - auto *x_data = input_tensor.data(); - auto *vector_data = - eigen_vectors->mutable_data(dims, ctx.GetPlace()); - ComputeFloatEigenvaluesAndVectors( - x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); - } - if (has_vectors) { - *eigen_vectors = dito.Transpose(*eigen_vectors); - } + auto *x_data = input_tensor.data(); + auto *vector_data = eigen_vectors->mutable_data(dims, ctx.GetPlace()); + ComputeEigenvaluesAndVectors(x_data, value_data, vector_data, batch_size, + rows, has_vectors); } + + inline void ComputeEigenvaluesAndVectors(T *x_data, ValueType *value_data, + T *vector_data, int batches, + int rows, bool has_vectors) const; }; +#define EIGEN_INSTANCE(ValueType, T, CastType) \ + template <> \ + inline void MatrixEighFunctor:: \ + ComputeEigenvaluesAndVectors(T *x_data, ValueType *value_data, \ + T *vector_data, int batches, int rows, \ + bool has_vectors) const { \ + int stride = rows * rows; \ + for (int i = 0; i < batches; i++) { \ + auto x_data_ = reinterpret_cast(x_data); \ + auto vector_data_ = reinterpret_cast(vector_data); \ + auto eigenvalues = \ + OutputMatrixMap(value_data + i * rows, 1, rows); \ + auto m = InputMatrixMap(x_data_ + i * stride, rows, rows); \ + auto eigenvectors = \ + OutputMatrixMap(vector_data_ + i * stride, rows, rows); \ + Eigen::SelfAdjointEigenSolver> eigen_solver( \ + m, \ + has_vectors ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); \ + PADDLE_ENFORCE_EQ( \ + eigen_solver.info(), Eigen::Success, \ + platform::errors::InvalidArgument( \ + "Self Adjoint Eigen decomposition is not successful. " \ + "The %d-th input matrice might not be not be positive " \ + "definite.", \ + i)); \ + eigenvalues = eigen_solver.eigenvalues().transpose(); \ + if (has_vectors) { \ + eigenvectors = eigen_solver.eigenvectors(); \ + } \ + } \ + } + +EIGEN_INSTANCE(float, float, float); +EIGEN_INSTANCE(double, double, double); +EIGEN_INSTANCE(double, paddle::platform::complex, std::complex); +EIGEN_INSTANCE(float, paddle::platform::complex, std::complex); + #ifdef PADDLE_WITH_CUDA // Calculates the eigenvalues ​​and eigenvectors of Hermitian or real // symmetric matrices on GPU, and uses the variable has_vectors // to control whether to return the eigenvectors. template -struct MatrixEighFunctor { +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, @@ -278,7 +245,8 @@ struct MatrixEighFunctor { #define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \ template <> \ - inline void MatrixEighFunctor::EvdBuffer( \ + inline void \ + MatrixEighFunctor::EvdBuffer( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \ int *lwork) const { \ @@ -292,7 +260,8 @@ FUNC_WITH_TYPES(EVDBUFFER_INSTANCE); #define EVD_INSTANCE(ValueType, T, C, CastType) \ template <> \ - inline void MatrixEighFunctor::Evd( \ + inline void \ + MatrixEighFunctor::Evd( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \ int lwork, int *devInfo) const { \ diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 71d106c211f71a..1fa8c13973d836 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations { framework::Tensor Div(const framework::Tensor& x, const framework::Tensor& y) { framework::Tensor ret; - std::vector out_shape = GetBroadcastShape({&x, &y}); - ret.Resize(framework::make_ddim(out_shape)); - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, DivFunctor(), &ret); + if (x.type() != y.type()) { + ret.mutable_data(x.dims(), context.GetPlace()); + auto x_vector = EigenVector::Flatten(x); + auto y_vector = EigenVector::Flatten(y); + auto out_vector = EigenVector::Flatten(ret); + auto& place = + *context.template device_context().eigen_device(); + out_vector.device(place) = x_vector / y_vector; + } else { + std::vector out_shape = GetBroadcastShape({&x, &y}); + ret.Resize(framework::make_ddim(out_shape)); + ElementwiseComputeEx, DeviceContext, T>( + context, &x, &y, -1, DivFunctor(), &ret); + } return ret; } framework::Tensor Add(const framework::Tensor& x, @@ -330,7 +340,7 @@ struct DeviceIndependenceTensorOperations { NameInTensorMap inputs({{"X", {&x}}}); return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim); } - + template framework::Tensor Sub(const framework::Tensor& x, const framework::Tensor& y) { framework::Tensor ret; @@ -340,18 +350,18 @@ struct DeviceIndependenceTensorOperations { #if defined(__NVCC__) || defined(__HIPCC__) // For GPU, there is no need to define XxxInverseFunctor and call // ElementwiseComputeEx in two branches. - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, SubFunctor(), &ret); + ElementwiseComputeEx, DeviceContext, T1>( + context, &x, &y, -1, SubFunctor(), &ret); #endif } else { if (x.dims().size() >= y.dims().size()) { - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, SubFunctor(), &ret); + ElementwiseComputeEx, DeviceContext, T1>( + context, &x, &y, -1, SubFunctor(), &ret); } else { - ElementwiseComputeEx, DeviceContext, T>( + ElementwiseComputeEx, DeviceContext, T1>( // This is copyed from elementwise_sub, which means we // need reverse will xrank < yrank - context, &x, &y, -1, InverseSubFunctor(), &ret); + context, &x, &y, -1, InverseSubFunctor(), &ret); } } return ret; @@ -461,36 +471,23 @@ struct DeviceIndependenceTensorOperations { return out; } - // Support x and y are different data types - Tensor Div_(const Tensor& x, const Tensor& y) { - Tensor out; - out.mutable_data(x.dims(), context.GetPlace()); - auto x_vector = EigenVector::Flatten(x); - auto y_vector = EigenVector::Flatten(y); - auto out_vector = EigenVector::Flatten(out); - auto& place = - *context.template device_context().eigen_device(); - out_vector.device(place) = x_vector / y_vector; - return out; - } - - framework::Tensor Sub_(const framework::Tensor& x, - const framework::Tensor& y) { - framework::Tensor ret; - std::vector out_shape = GetBroadcastShape({&x, &y}); - ret.Resize(framework::make_ddim(out_shape)); - if (x.dims().size() >= y.dims().size()) { - ElementwiseComputeEx, DeviceContext, ValueType>( - context, &x, &y, -1, SubFunctor(), &ret); - } else { - ElementwiseComputeEx, DeviceContext, - ValueType>( - // This is copyed from elementwise_sub, which means we - // need reverse will xrank < yrank - context, &x, &y, -1, InverseSubFunctor(), &ret); - } - return ret; - } + // framework::Tensor Sub_(const framework::Tensor& x, + // const framework::Tensor& y) { + // framework::Tensor ret; + // std::vector out_shape = GetBroadcastShape({&x, &y}); + // ret.Resize(framework::make_ddim(out_shape)); + // if (x.dims().size() >= y.dims().size()) { + // ElementwiseComputeEx, DeviceContext, ValueType>( + // context, &x, &y, -1, SubFunctor(), &ret); + // } else { + // ElementwiseComputeEx, DeviceContext, + // ValueType>( + // // This is copyed from elementwise_sub, which means we + // // need reverse will xrank < yrank + // context, &x, &y, -1, InverseSubFunctor(), &ret); + // } + // return ret; + // } private: const framework::ExecutionContext& context; From 0ce25840961afb687e2b46e506445b4080d6f675 Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Fri, 17 Sep 2021 02:48:11 +0000 Subject: [PATCH 2/3] Modify the macro implementation --- .../operators/math/eigen_values_vectors.h | 71 ++++++++++--------- .../fluid/tests/unittests/test_eigh_op.py | 4 +- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 70a41f67d5ba83..20da7c85abf001 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -93,42 +93,45 @@ struct MatrixEighFunctor { int rows, bool has_vectors) const; }; -#define EIGEN_INSTANCE(ValueType, T, CastType) \ - template <> \ - inline void MatrixEighFunctor:: \ - ComputeEigenvaluesAndVectors(T *x_data, ValueType *value_data, \ - T *vector_data, int batches, int rows, \ - bool has_vectors) const { \ - int stride = rows * rows; \ - for (int i = 0; i < batches; i++) { \ - auto x_data_ = reinterpret_cast(x_data); \ - auto vector_data_ = reinterpret_cast(vector_data); \ - auto eigenvalues = \ - OutputMatrixMap(value_data + i * rows, 1, rows); \ - auto m = InputMatrixMap(x_data_ + i * stride, rows, rows); \ - auto eigenvectors = \ - OutputMatrixMap(vector_data_ + i * stride, rows, rows); \ - Eigen::SelfAdjointEigenSolver> eigen_solver( \ - m, \ - has_vectors ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); \ - PADDLE_ENFORCE_EQ( \ - eigen_solver.info(), Eigen::Success, \ - platform::errors::InvalidArgument( \ - "Self Adjoint Eigen decomposition is not successful. " \ - "The %d-th input matrice might not be not be positive " \ - "definite.", \ - i)); \ - eigenvalues = eigen_solver.eigenvalues().transpose(); \ - if (has_vectors) { \ - eigenvectors = eigen_solver.eigenvectors(); \ - } \ - } \ +#define EIGEN_WITH_TYPES(m) \ + m(float, float, float) m(double, double, double) \ + m(float, paddle::platform::complex, std::complex) \ + m(double, paddle::platform::complex, std::complex) + +#define EIGEN_INSTANCE(ValueType, T, CastType) \ + template <> \ + inline void MatrixEighFunctor:: \ + ComputeEigenvaluesAndVectors(T *x_data, ValueType *value_data, \ + T *vector_data, int batches, int rows, \ + bool has_vectors) const { \ + int stride = rows * rows; \ + for (int i = 0; i < batches; i++) { \ + auto x_data_ = reinterpret_cast(x_data); \ + auto vector_data_ = reinterpret_cast(vector_data); \ + auto eigenvalues = \ + OutputMatrixMap(value_data + i * rows, 1, rows); \ + auto m = InputMatrixMap(x_data_ + i * stride, rows, rows); \ + auto eigenvectors = \ + OutputMatrixMap(vector_data_ + i * stride, rows, rows); \ + Eigen::SelfAdjointEigenSolver> \ + eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors \ + : Eigen::EigenvaluesOnly); \ + PADDLE_ENFORCE_EQ( \ + eigen_solver.info(), Eigen::Success, \ + platform::errors::InvalidArgument( \ + "Self Adjoint Eigen decomposition is not successful. " \ + "The %d-th input matrice might not be not be positive " \ + "definite.", \ + i)); \ + eigenvalues = eigen_solver.eigenvalues().transpose(); \ + if (has_vectors) { \ + eigenvectors = eigen_solver.eigenvectors(); \ + } \ + } \ } -EIGEN_INSTANCE(float, float, float); -EIGEN_INSTANCE(double, double, double); -EIGEN_INSTANCE(double, paddle::platform::complex, std::complex); -EIGEN_INSTANCE(float, paddle::platform::complex, std::complex); +EIGEN_WITH_TYPES(EIGEN_INSTANCE); #ifdef PADDLE_WITH_CUDA diff --git a/python/paddle/fluid/tests/unittests/test_eigh_op.py b/python/paddle/fluid/tests/unittests/test_eigh_op.py index e4343647025255..8e8c9df199f142 100644 --- a/python/paddle/fluid/tests/unittests/test_eigh_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigh_op.py @@ -140,7 +140,7 @@ def test_in_static_mode(self): self.check_static_complex_result() def test_in_dynamic_mode(self): - paddle.disable_static(self.place) + paddle.disable_static() input_real_data = paddle.to_tensor(self.real_data) expected_w, expected_v = np.linalg.eigh(self.real_data) actual_w, actual_v = paddle.linalg.eigh(input_real_data) @@ -152,7 +152,7 @@ def test_in_dynamic_mode(self): self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) def test_eigh_grad(self): - paddle.disable_static(self.place) + paddle.disable_static() x = paddle.to_tensor(self.complex_data, stop_gradient=False) w, v = paddle.linalg.eigh(x) (w.sum() + paddle.abs(v).sum()).backward() From da51c6c4c5ed431869ee7c055f905b24d8c13634 Mon Sep 17 00:00:00 2001 From: Zjq9409 <15205085056@163.com> Date: Fri, 17 Sep 2021 02:48:11 +0000 Subject: [PATCH 3/3] Modify the macro implementation --- paddle/fluid/operators/math/eigen_values_vectors.h | 14 ++++++++------ .../paddle/fluid/tests/unittests/test_eigh_op.py | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 70a41f67d5ba83..a93e5b38564566 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -93,6 +93,11 @@ struct MatrixEighFunctor { int rows, bool has_vectors) const; }; +#define EIGEN_WITH_TYPES(m) \ + m(float, float, float) m(double, double, double) \ + m(float, paddle::platform::complex, std::complex) \ + m(double, paddle::platform::complex, std::complex) + #define EIGEN_INSTANCE(ValueType, T, CastType) \ template <> \ inline void MatrixEighFunctor:: \ @@ -101,8 +106,8 @@ struct MatrixEighFunctor { bool has_vectors) const { \ int stride = rows * rows; \ for (int i = 0; i < batches; i++) { \ - auto x_data_ = reinterpret_cast(x_data); \ - auto vector_data_ = reinterpret_cast(vector_data); \ + CastType *x_data_ = reinterpret_cast(x_data); \ + CastType *vector_data_ = reinterpret_cast(vector_data); \ auto eigenvalues = \ OutputMatrixMap(value_data + i * rows, 1, rows); \ auto m = InputMatrixMap(x_data_ + i * stride, rows, rows); \ @@ -125,10 +130,7 @@ struct MatrixEighFunctor { } \ } -EIGEN_INSTANCE(float, float, float); -EIGEN_INSTANCE(double, double, double); -EIGEN_INSTANCE(double, paddle::platform::complex, std::complex); -EIGEN_INSTANCE(float, paddle::platform::complex, std::complex); +EIGEN_WITH_TYPES(EIGEN_INSTANCE); #ifdef PADDLE_WITH_CUDA diff --git a/python/paddle/fluid/tests/unittests/test_eigh_op.py b/python/paddle/fluid/tests/unittests/test_eigh_op.py index e4343647025255..8e8c9df199f142 100644 --- a/python/paddle/fluid/tests/unittests/test_eigh_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigh_op.py @@ -140,7 +140,7 @@ def test_in_static_mode(self): self.check_static_complex_result() def test_in_dynamic_mode(self): - paddle.disable_static(self.place) + paddle.disable_static() input_real_data = paddle.to_tensor(self.real_data) expected_w, expected_v = np.linalg.eigh(self.real_data) actual_w, actual_v = paddle.linalg.eigh(input_real_data) @@ -152,7 +152,7 @@ def test_in_dynamic_mode(self): self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) def test_eigh_grad(self): - paddle.disable_static(self.place) + paddle.disable_static() x = paddle.to_tensor(self.complex_data, stop_gradient=False) w, v = paddle.linalg.eigh(x) (w.sum() + paddle.abs(v).sum()).backward()