diff --git a/CMakeLists.txt b/CMakeLists.txt index 1aa50c14d7e251..351bb33583450c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,8 @@ project(paddle CXX C) # enable language CUDA # TODO(Shibo Tao): remove find_package(CUDA) completely. find_package(CUDA QUIET) +find_package(MKL QUIET) +option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" ${MKL_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) @@ -371,6 +373,10 @@ if (WITH_MIPS) add_definitions(-DPADDLE_WITH_MIPS) endif() +if (WITH_ONEMKL) + add_definitions(-DPADDLE_WITH_ONEMKL) +endif() + if (WITH_HETERPS) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ca388ad94db354..585da3c0280652 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -91,21 +91,35 @@ if (WITH_GPU OR WITH_ROCM) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) - find_library(CUFFT_LIB libcufft.so - PATHS - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/ - NO_DEFAULT_PATH - ) - op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS}) - target_link_libraries(spectral_op ${CUFFT_LIB}) endif() op_library(sync_batch_norm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) - if(WITH_ONEMKL) - target_link_libraries(spectral_op MKL::MKL) - endif() +endif() + +op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS}) +if (WITH_GPU) + find_library(CUFFT_LIB libcufft.so + PATHS + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/ + NO_DEFAULT_PATH + ) + target_link_libraries(spectral_op ${CUFFT_LIB}) +endif() +if(WITH_ONEMKL) + find_library(ONEMKL_CORE libmkl_core.so + PATHS + ${MKL_ROOT}/lib/${MKL_ARCH} + NO_DEFAULT_PATH + ) + find_library(ONEMKL_THREAD libmkl_intel_thread.so + PATHS + ${MKL_ROOT}/lib/${MKL_ARCH} + NO_DEFAULT_PATH + ) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + target_link_libraries(spectral_op MKL::mkl_core MKL::mkl_intel_thread) endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index 7e508cfb0842a8..982128e38590eb 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -15,14 +15,17 @@ #include "paddle/fluid/operators/spectral_op.h" #include +#include +#include #include #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/complex.h" #if defined(PADDLE_WITH_ONEMKL) -// #include "mkl_dfti.h" +#include // #include "mkl_service.h" #elif defined(PADDLE_WITH_POCKETFFT) #include "extern_pocketfft/pocketfft_hdronly.h" @@ -349,34 +352,266 @@ T compute_factor(int64_t size, FFTNormMode normalization) { ////////////////// Functors #if defined(PADDLE_WITH_ONEMKL) -template -struct FFTC2CFunctor { + +static inline void MKL_DFTI_CHECK(MKL_INT status) { + if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { + PADDLE_THROW(DftiErrorMessage(status)); + } +} + +struct DftiDescriptorDeleter { + void operator()(DFTI_DESCRIPTOR_HANDLE handle) { + if (handle != nullptr) { + MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); + } + } +}; + +class DftiDescriptor { + public: + void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, + MKL_LONG signal_ndim, MKL_LONG* sizes) { + if (desc_ != nullptr) { + PADDLE_THROW("DFT DESCRIPTOR can only be initialized once."); + } + DFTI_DESCRIPTOR* raw_desc; + if (signal_ndim == 1) { + MKL_DFTI_CHECK( + DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); + } else { + MKL_DFTI_CHECK( + DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); + } + desc_.reset(raw_desc); + } + + DFTI_DESCRIPTOR* get() const { + if (desc_ == nullptr) { + PADDLE_THROW("DFTI DESCRIPTOR has not been initialized."); + } + return desc_.get(); + } + + private: + std::unique_ptr desc_; +}; + +DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, + const framework::proto::VarType::Type& out_dtype, + const framework::DDim& in_strides, + const framework::DDim& out_strides, + const std::vector& signal_sizes, + FFTNormMode normalization, bool forward) { + const DFTI_CONFIG_VALUE precision = [&] { + switch (in_dtype) { + case framework::proto::VarType::FP32: + return DFTI_SINGLE; + case framework::proto::VarType::COMPLEX64: + return DFTI_SINGLE; + case framework::proto::VarType::FP64: + return DFTI_DOUBLE; + case framework::proto::VarType::COMPLEX128: + return DFTI_SINGLE; + default: + PADDLE_THROW("MKL DFT does not support."); + } + }(); + + const bool complex_input = framework::IsComplexType(in_dtype); + const bool complex_output = framework::IsComplexType(out_dtype); + const DFTI_CONFIG_VALUE domain = [&] { + if (forward) { + return complex_input ? DFTI_COMPLEX : DFTI_REAL; + } else { + return complex_output ? DFTI_COMPLEX : DFTI_REAL; + } + }(); + + DftiDescriptor descriptor; ///// + std::vector fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); + const MKL_LONG signal_ndim = fft_sizes.size() - 1; + descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); + + // placement inplace? + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + + // number of transformation + const MKL_LONG batch_size = fft_sizes[0]; + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); + + // input & output distance + const MKL_LONG idist = in_strides[0]; + const MKL_LONG odist = out_strides[0]; + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + + // input & output stride + std::vector mkl_in_stride(1 + signal_ndim, 0); + std::vector mkl_out_stride(1 + signal_ndim, 0); + for (MKL_LONG i = 1; i <= signal_ndim; i++) { + mkl_in_stride[i] = in_strides[i]; + mkl_out_stride[i] = out_strides[i]; + } + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride)); + MKL_DFTI_CHECK( + DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride)); + + // conjugate even storage + if (!complex_input || !complex_output) { + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, + DFTI_COMPLEX_COMPLEX)); + } + + MKL_LONG signal_numel = + std::accumulate(fft_sizes.cbegin() + 1, fft_sizes.cend(), 1UL, + std::multiplies()); + if (normalization != FFTNormMode::none) { + const double scale = + ((normalization == FFTNormMode::by_sqrt_n) + ? 1.0 / std::sqrt(static_cast(signal_numel)) + : 1.0 / static_cast(signal_numel)); + const auto scale_direction = + forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE; + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); + } + + // commit the descriptor + MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); + return descriptor; +} + +// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) +template +void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, + const std::vector& axes, FFTNormMode normalization, + bool forward) { + const framework::DDim& in_sizes = x->dims(); + const int ndim = in_sizes.size(); + const int signal_ndim = axes.size(); + const int batch_ndim = ndim - signal_ndim; + const framework::DDim& out_sizes = out->dims(); + + // make a dim permutation + std::vector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), 0); + std::vector is_transformed_dim(ndim, false); + for (const auto& d : axes) { + is_transformed_dim[d] = true; + } + const auto batch_end = + std::partition(dim_permute.begin(), dim_permute.end(), + [&](size_t axis) { return !is_transformed_dim[axis]; }); + std::copy(axes.cbegin(), axes.cend(), batch_end); + + // transpose input according to that permutation + framework::DDim transposed_input_shape = in_sizes.transpose(dim_permute); + std::vector transposed_input_shape_ = + framework::vectorize(transposed_input_shape); + framework::Tensor transposed_input; + transposed_input.Resize(transposed_input_shape); + const auto place = ctx.GetPlace(); + transposed_input.mutable_data(place); + TransCompute(ndim, ctx, *x, &transposed_input, + dim_permute); + + // make an collapsed input: collapse batch axes for input + const int batch_size = std::accumulate( + transposed_input_shape.Get(), transposed_input_shape.Get() + batch_ndim, + 1L, std::multiplies()); + std::vector collapsed_input_shape_(1 + signal_ndim); + collapsed_input_shape_[0] = batch_size; + std::copy(transposed_input_shape_.begin() + batch_ndim, + transposed_input_shape_.end(), collapsed_input_shape_.begin() + 1); + const framework::DDim collapsed_input_shape = + framework::make_ddim(collapsed_input_shape_); + transposed_input.Resize(collapsed_input_shape); + framework::Tensor& collapsed_input = transposed_input; + + // make a collapsed output + std::vector collapsed_output_shape_(1 + signal_ndim); + collapsed_output_shape_[0] = batch_size; + for (int i = 0; i < signal_ndim; i++) { + collapsed_output_shape_[1 + i] = out_sizes[axes[i]]; + } + const framework::DDim collapsed_output_shape = + framework::make_ddim(collapsed_output_shape_); + framework::Tensor collapsed_output; + collapsed_output.Resize(collapsed_output_shape); + collapsed_output.mutable_data(place, out->type()); + + // signal sizes + std::vector signal_sizes(1 + signal_ndim); + signal_sizes[0] = batch_size; + for (int i = 0; i < signal_ndim; i++) { + signal_sizes[1 + i] = + std::max(collapsed_input_shape[1 + i], collapsed_output_shape[1 + i]); + } + + // input & output stride + const framework::DDim input_stride = framework::stride(collapsed_input_shape); + const framework::DDim output_stride = + framework::stride(collapsed_output_shape); + + // make a DFTI_DESCRIPTOR + DftiDescriptor desc = + _plan_mkl_fft(x->type(), out->type(), input_stride, output_stride, + signal_sizes, normalization, forward); + if (forward) { + MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data(), + collapsed_output.data())); + } else { + MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), collapsed_input.data(), + collapsed_output.data())); + } + + // resize for the collapsed output + framework::DDim transposed_output_shape = out_sizes.transpose(dim_permute); + collapsed_output.Resize(transposed_output_shape); + framework::Tensor& transposed_output = collapsed_output; + + // reverse the transposition + std::vector reverse_dim_permute(ndim); + for (int i = 0; i < ndim; i++) { + reverse_dim_permute[dim_permute[i]] = i; + } + TransCompute(ndim, ctx, transposed_output, + out, reverse_dim_permute); +} + +template +struct FFTC2CFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, - FFTNormMode normalization, bool forward) {} + FFTNormMode normalization, bool forward) { + exec_fft(ctx, x, out, axes, + normalization, forward); + } }; -template -struct FFTR2CFunctor { +template +struct FFTR2CFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward, bool onesided) {} }; -template -struct FFTC2RFunctor { +template +struct FFTC2RFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward) {} }; #elif defined(PADDLE_WITH_POCKETFFT) -template -struct FFTC2CFunctor { +template +struct FFTC2CFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward) { - using R = typename T::value_type; + using R = typename Ti::value_type; using C = std::complex; const auto& input_dim = x->dims(); @@ -388,12 +623,12 @@ struct FFTC2CFunctor { std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), [](int64_t s) { return s * data_size; }); - const auto* in_data = reinterpret_cast(x->data()); - auto* out_data = reinterpret_cast(out->data()); + const auto* in_data = reinterpret_cast(x->data()); + auto* out_data = reinterpret_cast(out->data()); // well, we have to use std::vector here std::vector axes_(axes.size()); std::copy(axes.begin(), axes.end(), axes_.begin()); - // compuet facet + // compuet factor int64_t signal_numel = 1; for (auto i : axes) { signal_numel *= in_sizes[i]; @@ -404,12 +639,12 @@ struct FFTC2CFunctor { } }; -template -struct FFTR2CFunctor { +template +struct FFTR2CFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward, bool onesided) { - using R = typename T::value_type; + using R = Ti; using C = std::complex; const auto& input_dim = x->dims(); @@ -436,7 +671,7 @@ struct FFTR2CFunctor { } const auto* in_data = x->data(); - auto* out_data = reinterpret_cast(out->data()); + auto* out_data = reinterpret_cast(out->data()); // well, we have to use std::vector here std::vector axes_(axes.size()); std::copy(axes.begin(), axes.end(), axes_.begin()); @@ -451,12 +686,12 @@ struct FFTR2CFunctor { } }; -template -struct FFTC2RFunctor { +template +struct FFTC2RFunctor { void operator()(const platform::CPUDeviceContext& ctx, const Tensor* x, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward) { - using R = typename T::value_type; + using R = To; using C = std::complex; const auto& input_dim = x->dims(); @@ -482,7 +717,7 @@ struct FFTC2RFunctor { [](int64_t s) { return s * data_size; }); } - const auto* in_data = reinterpret_cast(x->data()); + const auto* in_data = reinterpret_cast(x->data()); auto* out_data = out->data(); // well, we have to use std::vector here std::vector axes_(axes.size()); @@ -499,13 +734,6 @@ struct FFTC2RFunctor { }; #endif -// mkl fft for all cases -void exec_fft(const Tensor* x, Tensor* out, const std::vector& out_dim, - int64_t normalization, bool forward) { - // construct the descriptor - - // compute -} // namespace anonymous } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/spectral_op.h b/paddle/fluid/operators/spectral_op.h index a55f28103669aa..278253b7ef2d5c 100644 --- a/paddle/fluid/operators/spectral_op.h +++ b/paddle/fluid/operators/spectral_op.h @@ -28,21 +28,72 @@ enum class FFTNormMode : int64_t { FFTNormMode get_norm_from_string(const std::string& norm, bool forward); -template +// Enum representing the FFT type +enum class FFTTransformType : int8_t { + C2C, // Complex-to-complex + R2C, // Real-to-complex + C2R, // Complex-to-real +}; + +// Create transform type enum from bools representing if input and output are +// complex +inline FFTTransformType GetFFTTransformType( + framework::proto::VarType::Type input_dtype, + framework::proto::VarType::Type output_dtype) { + auto complex_input = framework::IsComplexType(input_dtype); + auto complex_output = framework::IsComplexType(output_dtype); + if (complex_input && complex_output) { + return FFTTransformType::C2C; + } else if (complex_input && !complex_output) { + return FFTTransformType::C2R; + } else if (!complex_input && complex_output) { + return FFTTransformType::R2C; + } + PADDLE_THROW( + platform::errors::InvalidArgument("Real to real FFTs are not supported")); +} + +// Returns true if the transform type has complex input +inline bool has_complex_input(FFTTransformType type) { + switch (type) { + case FFTTransformType::C2C: + case FFTTransformType::C2R: + return true; + + case FFTTransformType::R2C: + return false; + } + PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); +} + +// Returns true if the transform type has complex output +inline bool has_complex_output(FFTTransformType type) { + switch (type) { + case FFTTransformType::C2C: + case FFTTransformType::R2C: + return true; + + case FFTTransformType::C2R: + return false; + } + PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); +} + +template struct FFTC2CFunctor { void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward); }; -template +template struct FFTR2CFunctor { void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, const std::vector& axes, FFTNormMode normalization, bool forward, bool onesided); }; -template +template struct FFTC2RFunctor { void operator()(const DeviceContext& ctx, const Tensor* X, Tensor* out, const std::vector& axes, FFTNormMode normalization, @@ -53,7 +104,7 @@ template class FFTC2CKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -62,10 +113,10 @@ class FFTC2CKernel : public framework::OpKernel { const auto* x = ctx.Input("X"); auto* y = ctx.Output("Out"); - y->mutable_data(ctx.GetPlace()); + y->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTC2CFunctor fft_c2c_func; + FFTC2CFunctor fft_c2c_func; fft_c2c_func(dev_ctx, x, y, axes, normalization, forward); } }; @@ -74,7 +125,7 @@ template class FFTC2CGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -83,70 +134,19 @@ class FFTC2CGradKernel : public framework::OpKernel { const auto* dy = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); + dx->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTC2CFunctor fft_c2c_func; + FFTC2CFunctor fft_c2c_func; fft_c2c_func(dev_ctx, dy, dx, axes, normalization, forward); } }; -// Enum representing the FFT type -enum class FFTTransformType : int8_t { - C2C, // Complex-to-complex - R2C, // Real-to-complex - C2R, // Complex-to-real -}; - -// Create transform type enum from bools representing if input and output are -// complex -inline FFTTransformType GetFFTTransformType( - framework::proto::VarType::Type input_dtype, - framework::proto::VarType::Type output_dtype) { - auto complex_input = framework::IsComplexType(input_dtype); - auto complex_output = framework::IsComplexType(output_dtype); - if (complex_input && complex_output) { - return FFTTransformType::C2C; - } else if (complex_input && !complex_output) { - return FFTTransformType::C2R; - } else if (!complex_input && complex_output) { - return FFTTransformType::R2C; - } - PADDLE_THROW( - platform::errors::InvalidArgument("Real to real FFTs are not supported")); -} - -// Returns true if the transform type has complex input -inline bool has_complex_input(FFTTransformType type) { - switch (type) { - case FFTTransformType::C2C: - case FFTTransformType::C2R: - return true; - - case FFTTransformType::R2C: - return false; - } - PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); -} - -// Returns true if the transform type has complex output -inline bool has_complex_output(FFTTransformType type) { - switch (type) { - case FFTTransformType::C2C: - case FFTTransformType::R2C: - return true; - - case FFTTransformType::C2R: - return false; - } - PADDLE_THROW(platform::errors::InvalidArgument("Unknown FFTTransformType")); -} - template class FFTR2CKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -156,10 +156,10 @@ class FFTR2CKernel : public framework::OpKernel { const auto* x = ctx.Input("X"); auto* y = ctx.Output("Out"); - y->mutable_data(ctx.GetPlace()); + y->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTR2CFunctor fft_r2c_func; + FFTR2CFunctor fft_r2c_func; fft_r2c_func(dev_ctx, x, y, axes, normalization, forward, onesided); } }; @@ -168,7 +168,7 @@ template class FFTR2CGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -178,10 +178,10 @@ class FFTR2CGradKernel : public framework::OpKernel { const auto* dy = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); + dx->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTC2RFunctor fft_c2r_func; + FFTC2RFunctor fft_c2r_func; fft_c2r_func(dev_ctx, dy, dx, axes, normalization, forward); } }; @@ -190,7 +190,7 @@ template class FFTC2RKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -202,7 +202,7 @@ class FFTC2RKernel : public framework::OpKernel { y->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTC2RFunctor fft_c2r_func; + FFTC2RFunctor fft_c2r_func; fft_c2r_func(dev_ctx, x, y, axes, normalization, forward); } }; @@ -211,7 +211,7 @@ template class FFTC2RGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using U = paddle::platform::complex; + using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); auto axes = ctx.Attr>("axes"); @@ -221,10 +221,10 @@ class FFTC2RGradKernel : public framework::OpKernel { const auto* dy = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); + dx->mutable_data(ctx.GetPlace()); auto normalization = get_norm_from_string(norm_str, forward); - FFTR2CFunctor fft_r2c_func; + FFTR2CFunctor fft_r2c_func; fft_r2c_func(dev_ctx, dy, dx, axes, normalization, forward, onesided); } };