From 9c3196ef6ccf26af52a5c066816c63144cfd2ee5 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Wed, 15 Sep 2021 18:00:36 +0800 Subject: [PATCH 1/7] use std::ptrdiff_t as datatype of stride (instead of int64_t) to avoid argument mismatch on some platforms. --- paddle/fluid/operators/spectral_op.cc | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index 5fb1f8673df01f..b9359f48c6f295 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -681,11 +681,11 @@ struct FFTC2CFunctor { const auto& input_dim = x->dims(); const std::vector in_sizes = framework::vectorize(input_dim); - std::vector in_strides = - framework::vectorize(framework::stride(input_dim)); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); const int64_t data_size = sizeof(C); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](int64_t s) { return s * data_size; }); + [](std::ptrdiff_t s) { return s * data_size; }); const auto* in_data = reinterpret_cast(x->data()); auto* out_data = reinterpret_cast(out->data()); @@ -714,24 +714,24 @@ struct FFTR2CFunctor { const auto& input_dim = x->dims(); const std::vector in_sizes = framework::vectorize(input_dim); - std::vector in_strides = - framework::vectorize(framework::stride(input_dim)); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); { const int64_t data_size = sizeof(R); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](int64_t s) { return s * data_size; }); + [](std::ptrdiff_t s) { return s * data_size; }); } const auto& output_dim = out->dims(); const std::vector out_sizes = framework::vectorize(output_dim); - std::vector out_strides = - framework::vectorize(framework::stride(output_dim)); + std::vector out_strides = + framework::vectorize(framework::stride(output_dim)); { const int64_t data_size = sizeof(C); std::transform(out_strides.begin(), out_strides.end(), out_strides.begin(), - [](int64_t s) { return s * data_size; }); + [](std::ptrdiff_t s) { return s * data_size; }); } const auto* in_data = x->data(); @@ -761,24 +761,24 @@ struct FFTC2RFunctor { const auto& input_dim = x->dims(); const std::vector in_sizes = framework::vectorize(input_dim); - std::vector in_strides = - framework::vectorize(framework::stride(input_dim)); + std::vector in_strides = + framework::vectorize(framework::stride(input_dim)); { const int64_t data_size = sizeof(C); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](int64_t s) { return s * data_size; }); + [](std::ptrdiff_t s) { return s * data_size; }); } const auto& output_dim = out->dims(); const std::vector out_sizes = framework::vectorize(output_dim); - std::vector out_strides = - framework::vectorize(framework::stride(output_dim)); + std::vector out_strides = + framework::vectorize(framework::stride(output_dim)); { const int64_t data_size = sizeof(R); std::transform(out_strides.begin(), out_strides.end(), out_strides.begin(), - [](int64_t s) { return s * data_size; }); + [](std::ptrdiff_t s) { return s * data_size; }); } const auto* in_data = reinterpret_cast(x->data()); From 5c1d9aad093c232d03dbb7619dffb812f65d2fdb Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Wed, 15 Sep 2021 19:50:27 +0800 Subject: [PATCH 2/7] add complex support for fill_zeros_like --- paddle/fluid/operators/fill_zeros_like_op.cc | 13 +++++++++++-- 
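A note on why patch 1 matters: the CPU FFT backends consume stride vectors typed as std::ptrdiff_t (pocketfft, for instance, takes its strides as std::ptrdiff_t counted in bytes), and std::vector<int64_t> is a distinct type from std::vector<std::ptrdiff_t> wherever the underlying integer types differ — on macOS, int64_t is long long while ptrdiff_t is long, even though both are 64 bits wide. Hence the "argument mismatch on some platforms". The computation the patched code performs — derive row-major element strides from the dims, then scale them to byte strides — can be sketched outside Paddle like this (framework::vectorize and framework::stride are Paddle helpers; plain std::vector stands in for them here, so this is an illustration, not the framework code):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major element strides for a shape, scaled to byte strides.
// A minimal re-creation of what the patched functors compute.
std::vector<std::ptrdiff_t> byte_strides(const std::vector<int64_t>& dims,
                                         std::ptrdiff_t element_size) {
  std::vector<std::ptrdiff_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * static_cast<std::ptrdiff_t>(dims[i + 1]);
  }
  // Scale element strides to byte strides, as the patch does with
  // std::transform and data_size = sizeof(element type).
  std::transform(strides.begin(), strides.end(), strides.begin(),
                 [element_size](std::ptrdiff_t s) { return s * element_size; });
  return strides;
}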
paddle/fluid/operators/fill_zeros_like_op.cu.cc | 13 +++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/fill_zeros_like_op.cc b/paddle/fluid/operators/fill_zeros_like_op.cc index c727c657ed79d5..2d340829332c81 100644 --- a/paddle/fluid/operators/fill_zeros_like_op.cc +++ b/paddle/fluid/operators/fill_zeros_like_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fill_zeros_like_op.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -93,7 +94,11 @@ REGISTER_OP_CPU_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); REGISTER_OP_CPU_KERNEL( fill_zeros_like2, @@ -101,4 +106,8 @@ REGISTER_OP_CPU_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); diff --git a/paddle/fluid/operators/fill_zeros_like_op.cu.cc b/paddle/fluid/operators/fill_zeros_like_op.cu.cc index 1831635def79b3..4cb0887c1f326c 100644 --- a/paddle/fluid/operators/fill_zeros_like_op.cu.cc +++ b/paddle/fluid/operators/fill_zeros_like_op.cu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); REGISTER_OP_CUDA_KERNEL( fill_zeros_like2, @@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL( ops::FillZerosLikeKernel, ops::FillZerosLikeKernel, - ops::FillZerosLikeKernel); + ops::FillZerosLikeKernel, + ops::FillZerosLikeKernel>, + ops::FillZerosLikeKernel>); From ce663bef858d7123d0f8d146897b6e5204ce0e16 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 16 Sep 2021 02:38:01 +0800 Subject: [PATCH 3/7] use dynload for cufft --- paddle/fluid/operators/CMakeLists.txt | 47 ++++---- paddle/fluid/operators/spectral_op.cc | 18 ++- paddle/fluid/operators/spectral_op.cu | 25 ++-- paddle/fluid/platform/dynload/CMakeLists.txt | 2 +- paddle/fluid/platform/dynload/cufft.cc | 44 +++++++ paddle/fluid/platform/dynload/cufft.h | 113 ++++++++++++++++++ .../fluid/platform/dynload/dynamic_loader.cc | 11 ++ .../fluid/platform/dynload/dynamic_loader.h | 1 + .../fluid/tests/unittests/fft/test_fft.py | 15 +-- 9 files changed, 233 insertions(+), 43 deletions(-) create mode 100644 paddle/fluid/platform/dynload/cufft.cc create mode 100644 paddle/fluid/platform/dynload/cufft.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a9e29fb8361b9c..a4eea956fca32a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -98,29 +98,30 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() -op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS}) -if (WITH_GPU) - find_library(CUFFT_LIB libcufft.so - PATHS - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/ - NO_DEFAULT_PATH - ) - target_link_libraries(spectral_op ${CUFFT_LIB}) -endif() 
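The fill_zeros_like registrations in patch 2 add complex support purely by instantiating the existing templated kernel for platform::complex<float> and platform::complex<double>; no kernel logic changes. Stripped of the registry machinery, the pattern reduces to the following (a simplified stand-in using std::complex — Paddle's platform::complex is its own type):

#include <complex>
#include <vector>

// Fill an output the same size as the input with zeros. T(0)
// value-initializes both the real and imaginary parts when T is complex,
// which is why the kernel body needs no change.
template <typename T>
void fill_zeros_like(const std::vector<T>& x, std::vector<T>* out) {
  out->assign(x.size(), T(0));
}

// Supporting a new element type is one more explicit instantiation,
// which is all the extra REGISTER_OP_*_KERNEL entries amount to:
template void fill_zeros_like<std::complex<float>>(
    const std::vector<std::complex<float>>&,
    std::vector<std::complex<float>>*);
template void fill_zeros_like<std::complex<double>>(
    const std::vector<std::complex<double>>&,
    std::vector<std::complex<double>>*);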
-if(WITH_ONEMKL) - find_library(ONEMKL_CORE libmkl_core.so - PATHS - ${MKL_ROOT}/lib/${MKL_ARCH} - NO_DEFAULT_PATH - ) - find_library(ONEMKL_THREAD libmkl_intel_thread.so - PATHS - ${MKL_ROOT}/lib/${MKL_ARCH} - NO_DEFAULT_PATH - ) - target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) - target_link_libraries(spectral_op MKL::mkl_core MKL::mkl_intel_thread) -endif() +op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) +# op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS}) +# if (WITH_GPU) +# find_library(CUFFT_LIB libcufft.so +# PATHS +# ${CUDA_TOOLKIT_ROOT_DIR}/lib64/ +# NO_DEFAULT_PATH +# ) +# target_link_libraries(spectral_op ${CUFFT_LIB}) +# endif() +# if(WITH_ONEMKL) +# find_library(ONEMKL_CORE libmkl_core.so +# PATHS +# ${MKL_ROOT}/lib/${MKL_ARCH} +# NO_DEFAULT_PATH +# ) +# find_library(ONEMKL_THREAD libmkl_intel_thread.so +# PATHS +# ${MKL_ROOT}/lib/${MKL_ARCH} +# NO_DEFAULT_PATH +# ) +# target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) +# target_link_libraries(spectral_op MKL::mkl_core MKL::mkl_intel_thread) +# endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(eye_op DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index b9359f48c6f295..ebc06bac5328fa 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -240,15 +240,29 @@ class FFTC2ROp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2r"); const auto axes = ctx->Attrs().Get>("axes"); + const auto x_dim = ctx->GetInputDim("X"); + for (size_t i = 0; i < axes.size() - 1L; i++) { + const auto fft_n_point = (x_dim[axes[i]] - 1) * 2; + PADDLE_ENFORCE_GT(fft_n_point, 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", fft_n_point)); + } const int64_t last_dim_size = ctx->Attrs().Get("last_dim_size"); framework::DDim out_dim(ctx->GetInputDim("X")); const int64_t last_fft_axis = axes.back(); if (last_dim_size == 0) { const int64_t last_fft_dim_size = out_dim.at(last_fft_axis); - out_dim.at(last_fft_axis) = (last_fft_dim_size - 1) * 2; + const int64_t fft_n_point = (last_fft_dim_size - 1) * 2; + PADDLE_ENFORCE_GT(fft_n_point, 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", fft_n_point)); + out_dim.at(last_fft_axis) = fft_n_point; } else { - out_dim.at(last_fft_axis) = ctx->Attrs().Get("last_dim_size"); + PADDLE_ENFORCE_GT(last_dim_size, 0, + platform::errors::InvalidArgument( + "Invalid fft n-point (%d).", last_dim_size)); + out_dim.at(last_fft_axis) = last_dim_size; } ctx->SetOutputDim("Out", out_dim); } diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index a591103663a58b..7441a06a88c10f 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -26,6 +26,7 @@ #include "paddle/fluid/operators/conj_op.h" #include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/dynload/cufft.h" namespace paddle { namespace operators { @@ -141,7 +142,7 @@ class CuFFTHandle { ::cufftHandle handle_; public: - CuFFTHandle() { CUFFT_CHECK(cufftCreate(&handle_)); } + CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); } ::cufftHandle& get() { return handle_; } const ::cufftHandle& get() const { return handle_; } @@ -149,7 +150,8 @@ class CuFFTHandle { ~CuFFTHandle() { // 
Not using fftDestroy() for rocFFT to work around double freeing of handles #ifndef __HIPCC__ - cufftDestroy(handle_); + std::cout << "Dtor of CuFFTHandle" << std::endl; + CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); #endif } }; @@ -245,7 +247,8 @@ class CuFFTConfig { #endif // disable auto allocation of workspace to use THC allocator - CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0)); + CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); size_t ws_size_t; @@ -258,7 +261,7 @@ class CuFFTConfig { batch, &ws_size_t)); #else - CUFFT_CHECK(cufftXtMakePlanMany( + CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany( plan(), signal_ndim, signal_sizes.data(), /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, @@ -364,6 +367,11 @@ class PlanLRUCache { return *this; } + ~PlanLRUCache() { + std::cout << "DTor of PlanLRUCache" << std::endl; + clear(); + } + // If key is in this cache, return the cached config. Otherwise, emplace the // config in this cache and return it. CuFFTConfig& lookup(PlanKey params) { @@ -498,8 +506,8 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, PADDLE_THROW(platform::errors::InvalidArgument( "hipFFT only support transforms of type float32 and float64")); #else - CUFFT_CHECK(cufftXtExec(plan, in_data, out_data, - forward ? CUFFT_FORWARD : CUFFT_INVERSE)); + CUFFT_CHECK(platform::dynload::cufftXtExec( + plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE)); #endif } @@ -641,10 +649,11 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, auto& plan = config->plan(); // prepare cufft for execution - CUFFT_CHECK(cufftSetStream(plan, ctx.stream())); + CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream())); framework::Tensor workspace_tensor; workspace_tensor.mutable_data(tensor_place, config->workspace_size()); - CUFFT_CHECK(cufftSetWorkArea(plan, workspace_tensor.data())); + CUFFT_CHECK( + platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data())); // execute transform plan if (fft_type == FFTTransformType::C2R && forward) { diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index ac98ff02035bdb..eed3568f1d86be 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) -list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc) +list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc) if (NOT WITH_NV_JETSON) list(APPEND CUDA_SRCS nvjpeg.cc) diff --git a/paddle/fluid/platform/dynload/cufft.cc b/paddle/fluid/platform/dynload/cufft.cc new file mode 100644 index 00000000000000..a125fb7226050b --- /dev/null +++ b/paddle/fluid/platform/dynload/cufft.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
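Two things are happening in the spectral_op.cu hunks above: every direct cuFFT call is rerouted through platform::dynload, and cuFFT's automatic workspace allocation stays disabled (cufftSetAutoAllocation(plan, 0)) so the workspace can be a framework::Tensor carved from Paddle's allocator — the workspace_tensor.mutable_data / cufftSetWorkArea pairing. The bare cuFFT sequence being wrapped looks roughly like this (a sketch with plain cudaMalloc standing in for the workspace tensor; error handling elided):

#include <cuda_runtime.h>
#include <cufft.h>

void run_c2c_1d(cufftComplex* in, cufftComplex* out, int n,
                cudaStream_t stream) {
  cufftHandle plan;
  cufftCreate(&plan);
  cufftSetAutoAllocation(plan, /* autoAllocate */ 0);  // we own the workspace

  size_t ws_size = 0;
  cufftMakePlan1d(plan, n, CUFFT_C2C, /* batch */ 1, &ws_size);

  void* workspace = nullptr;
  cudaMalloc(&workspace, ws_size);  // Paddle uses a workspace Tensor here
  cufftSetWorkArea(plan, workspace);
  cufftSetStream(plan, stream);     // execute on the framework's stream

  cufftExecC2C(plan, in, out, CUFFT_FORWARD);

  cudaFree(workspace);
  cufftDestroy(plan);
}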
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/cufft.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace dynload { +std::once_flag cufft_dso_flag; +void* cufft_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CUFFT_FFT_ROUTINE_EACH(DEFINE_WRAP); + +bool HasCUFFT() { + std::call_once(cufft_dso_flag, + []() { cufft_dso_handle = GetCUFFTDsoHandle(); }); + return cufft_dso_handle != nullptr; +} + +void EnforceCUFFTLoaded(const char* fn_name) { + PADDLE_ENFORCE_NOT_NULL( + cufft_dso_handle, + platform::errors::PreconditionNotMet( + "Cannot load cufft shared library. Cannot invoke method %s.", + fn_name)); +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cufft.h b/paddle/fluid/platform/dynload/cufft.h new file mode 100644 index 00000000000000..ef924d7b5ee865 --- /dev/null +++ b/paddle/fluid/platform/dynload/cufft.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#ifdef PADDLE_WITH_CUDA +#include +#include +#include +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag cufft_dso_flag; +extern void* cufft_dso_handle; +extern bool HasCUFFT(); + +extern void EnforceCUFFTLoaded(const char* fn_name); +#define DECLARE_DYNAMIC_LOAD_CUFFT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + using cufft_func = decltype(&::__name); \ + std::call_once(cufft_dso_flag, []() { \ + cufft_dso_handle = paddle::platform::dynload::GetCUFFTDsoHandle(); \ + }); \ + EnforceCUFFTLoaded(#__name); \ + static void* p_##__name = dlsym(cufft_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern struct DynLoad__##__name __name + +/** + * include all needed cufft functions in HPPL + * different cufft version has different interfaces + **/ +#define CUFFT_FFT_ROUTINE_EACH(__macro) \ + __macro(cufftPlan1d); \ + __macro(cufftPlan2d); \ + __macro(cufftPlan3d); \ + __macro(cufftPlanMany); \ + __macro(cufftMakePlan1d); \ + __macro(cufftMakePlan2d); \ + __macro(cufftMakePlan3d); \ + __macro(cufftMakePlanMany); \ + __macro(cufftMakePlanMany64); \ + __macro(cufftGetSizeMany64); \ + __macro(cufftEstimate1d); \ + __macro(cufftEstimate2d); \ + __macro(cufftEstimate3d); \ + __macro(cufftEstimateMany); \ + __macro(cufftCreate); \ + __macro(cufftGetSize1d); \ + __macro(cufftGetSize2d); \ + __macro(cufftGetSize3d); \ + __macro(cufftGetSizeMany); \ + __macro(cufftGetSize); \ + __macro(cufftSetWorkArea); \ + __macro(cufftSetAutoAllocation); \ + __macro(cufftExecC2C); \ + __macro(cufftExecR2C); \ + __macro(cufftExecC2R); \ + __macro(cufftExecZ2Z); \ + __macro(cufftExecD2Z); \ + __macro(cufftExecZ2D); \ + __macro(cufftSetStream); \ + __macro(cufftDestroy); \ + __macro(cufftGetVersion); \ + __macro(cufftGetProperty); \ + __macro(cufftXtSetGPUs); \ + __macro(cufftXtMalloc); \ + __macro(cufftXtMemcpy); \ + __macro(cufftXtFree); \ + __macro(cufftXtSetWorkArea); \ + __macro(cufftXtExecDescriptorC2C); \ + __macro(cufftXtExecDescriptorR2C); \ + __macro(cufftXtExecDescriptorC2R); \ + __macro(cufftXtExecDescriptorZ2Z); \ + __macro(cufftXtExecDescriptorD2Z); \ + __macro(cufftXtExecDescriptorZ2D); \ + __macro(cufftXtQueryPlan); \ + __macro(cufftXtSetCallback); \ + __macro(cufftXtClearCallback); \ + __macro(cufftXtSetCallbackSharedSize); \ + __macro(cufftXtMakePlanMany); \ + __macro(cufftXtGetSizeMany); \ + __macro(cufftXtExec); \ + __macro(cufftXtExecDescriptor); \ + __macro(cufftXtSetWorkAreaPolicy); + +CUFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUFFT_WRAP) + +} // namespace dynload +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 37932600e7a7ea..d43e901371582a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -122,6 +122,7 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; +static constexpr char* win_cufft_lib = "cufft64_" CUDA_MAJOR_VERSION ".dll"; #endif // CUDA_VERSION #endif @@ -489,6 +490,16 @@ void* GetNvtxDsoHandle() { #endif } +void* GetCUFFTDsoHandle() { +#if defined(_WIN32) && defined(PADDLE_WITH_CUDA) + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cufft_lib, true, + {cuda_lib_path}); +#else + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so", false, + {cuda_lib_path}); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index e282c033c4451b..08f0aec8b0179a 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -41,6 +41,7 @@ 
void* GetTensorRtDsoHandle(); void* GetMKLMLDsoHandle(); void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); +void* GetCUFFTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index 6bfc7decdae1e0..a06dffb78c8a16 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -359,12 +359,6 @@ def test_hfftn(self): ('test_x_complex128', (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) ).astype(np.complex128), None, (-2, -1), "backward"), - ('test_n_grater_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [1, 2], (-2, -1), - "backward"), - ('test_n_smaller_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 1], (-2, -1), - "backward"), ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), "backward"), @@ -394,9 +388,6 @@ def test_hfft2(self): ('test_x_complex128', (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) ).astype(np.complex128), None, (-2, -1), "backward"), - ('test_n_equal_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), - "backward"), ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), "backward"), @@ -514,6 +505,9 @@ def test_irfft(self): ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, -1, 'backward', ValueError), \ + ('test_zero_n_point', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 1], (-2, -1), + "backward", ValueError), ('test_norm_not_in_enum_value', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, None, 'random', ValueError)]) @@ -539,6 +533,9 @@ def test_hfft2(self): [('test_n_nagative', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), (-2, -1), 'backward', ValueError), \ + ('test_n_equal_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), + "backward", ValueError), \ ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0), (-2, -1), 'backward', ValueError), \ ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), From 254dddc321162a223900fda6acf463610f96f64f Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 16 Sep 2021 15:16:26 +0800 Subject: [PATCH 4/7] 1. fix unittest; 2. temporarily disable Rocm. 
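The DECLARE_DYNAMIC_LOAD_CUFFT_WRAP machinery in patch 3 is Paddle's standard dynload pattern: open libcufft lazily exactly once, resolve each symbol on first use, and cache the function pointer in a static. Reduced to its essentials for a single entry point (POSIX dlopen/dlsym, no Paddle types — an illustration of the idea, not the macro's exact expansion):

#include <dlfcn.h>

#include <mutex>

#include <cufft.h>  // for the declaration of ::cufftCreate (never linked)

namespace {
std::once_flag cufft_dso_flag;
void* cufft_dso_handle = nullptr;
}  // namespace

// First call loads the library and resolves the symbol; later calls
// reuse the cached pointer. The real macro generates one such functor
// per cuFFT routine listed in CUFFT_FFT_ROUTINE_EACH, and checks that
// the handle loaded successfully (elided here).
cufftResult my_cufftCreate(cufftHandle* plan) {
  std::call_once(cufft_dso_flag,
                 [] { cufft_dso_handle = dlopen("libcufft.so", RTLD_LAZY); });
  using fn_t = decltype(&::cufftCreate);  // unevaluated: no link dependency
  static void* p = dlsym(cufft_dso_handle, "cufftCreate");
  return reinterpret_cast<fn_t>(p)(plan);
}

Because the symbols are resolved at runtime, libcufft drops out of the link line entirely — which is why the CMake find_library(CUFFT_LIB ...) block could be commented out in favor of DEPS dynload_cuda.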
--- paddle/fluid/operators/spectral_op.cc | 5 ++--- paddle/fluid/operators/spectral_op.cu | 8 +++----- python/paddle/fluid/layers/nn.py | 6 ++++-- .../paddle/fluid/tests/unittests/fft/test_fft.py | 14 ++++++-------- .../unittests/fft/test_fft_with_static_graph.py | 6 +++--- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index ebc06bac5328fa..fb0be9ba68fcf4 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -242,10 +242,9 @@ class FFTC2ROp : public framework::OperatorWithKernel { const auto axes = ctx->Attrs().Get>("axes"); const auto x_dim = ctx->GetInputDim("X"); for (size_t i = 0; i < axes.size() - 1L; i++) { - const auto fft_n_point = (x_dim[axes[i]] - 1) * 2; - PADDLE_ENFORCE_GT(fft_n_point, 0, + PADDLE_ENFORCE_GT(x_dim[axes[i]], 0, platform::errors::InvalidArgument( - "Invalid fft n-point (%d).", fft_n_point)); + "Invalid fft n-point (%d).", x_dim[axes[i]])); } const int64_t last_dim_size = ctx->Attrs().Get("last_dim_size"); diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 7441a06a88c10f..76cb04de8df29a 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -9,6 +9,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef PADDLE_WITH_CUDA #include #include @@ -150,7 +151,6 @@ class CuFFTHandle { ~CuFFTHandle() { // Not using fftDestroy() for rocFFT to work around double freeing of handles #ifndef __HIPCC__ - std::cout << "Dtor of CuFFTHandle" << std::endl; CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); #endif } @@ -367,10 +367,7 @@ class PlanLRUCache { return *this; } - ~PlanLRUCache() { - std::cout << "DTor of PlanLRUCache" << std::endl; - clear(); - } + ~PlanLRUCache() { clear(); } // If key is in this cache, return the cached config. Otherwise, emplace the // config in this cache and return it. @@ -869,3 +866,4 @@ REGISTER_OP_CUDA_KERNEL( fft_r2c_grad, ops::FFTR2CGradKernel, ops::FFTR2CGradKernel); +#endif diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3796585f089c24..7870355d5c4aa4 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6749,8 +6749,10 @@ def pad(x, paddings, pad_value=0., name=None): x = fluid.data(name='data', shape=[300, 300], dtype='float32') out = fluid.layers.pad(x=x, paddings=[0, 1, 1, 2], pad_value=0.) 
""" - check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], "pad") + check_variable_and_dtype(x, 'x', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64', + 'complex128' + ], "pad") helper = LayerHelper('pad', **locals()) dtype = helper.input_dtype(input_param_name='x') diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index a06dffb78c8a16..92b46363b0f114 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -359,6 +359,8 @@ def test_hfftn(self): ('test_x_complex128', (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_with_s', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), + [2, 2], (-2, -1), "backward", ValueError), ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), "backward"), @@ -388,6 +390,9 @@ def test_hfft2(self): ('test_x_complex128', (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) ).astype(np.complex128), None, (-2, -1), "backward"), + ('test_n_equal_input_length', + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (4, 6), (-2, -1), + "backward"), \ ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), "backward"), @@ -503,11 +508,7 @@ def test_irfft(self): np.random.randn(4) + 1j * np.random.randn(4), None, (1, 2), 'backward', ValueError), \ ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, -1, - 'backward', - ValueError), \ - ('test_zero_n_point', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 1], (-2, -1), - "backward", ValueError), + 'backward', ValueError), \ ('test_norm_not_in_enum_value', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), None, None, 'random', ValueError)]) @@ -533,9 +534,6 @@ def test_hfft2(self): [('test_n_nagative', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), (-2, -1), 'backward', ValueError), \ - ('test_n_equal_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), - "backward", ValueError), \ ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0), (-2, -1), 'backward', ValueError), \ ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py index 8aa194813f57df..6cd63f7d00a58b 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py @@ -322,10 +322,10 @@ def test_static_hfftn(self): (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) ).astype(np.complex128), None, (-2, -1), "backward"), ('test_n_grater_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [1, 2], (-2, -1), + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [4, 8], (-2, -1), "backward"), ('test_n_smaller_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 1], (-2, -1), + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), [2, 4], (-2, -1), "backward"), ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), @@ -357,7 +357,7 @@ def test_static_hfft2(self): (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4) 
).astype(np.complex128), None, (-2, -1), "backward"), ('test_n_equal_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), + np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 4), (-2, -1), "backward"), ('test_axis_not_last', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None, (-2, -1), From 76822d39ccaa5d2dfc35ab70859f0cd9721c8d1c Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 16 Sep 2021 16:38:30 +0800 Subject: [PATCH 5/7] fix compile error: only link dyload_cuda when cuda is available --- paddle/fluid/operators/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a4eea956fca32a..31594c3b9863ce 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -98,7 +98,12 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() -op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) +if (WITH_GPU AND (NOT WITH_ROCM)) + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) +else() + op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) +endif() + # op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS}) # if (WITH_GPU) # find_library(CUFFT_LIB libcufft.so From 649c52161370071442e2c5e9fa20d47e1326dada Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 16 Sep 2021 18:36:44 +0800 Subject: [PATCH 6/7] 1. fix dynload for cufft on windows; 2. fix unittests. --- paddle/fluid/platform/dynload/dynamic_loader.cc | 7 ++++++- python/paddle/fluid/tests/unittests/fft/test_fft.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index d43e901371582a..629a50561d9c68 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -109,6 +109,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll;cusparse64_10.dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll;cufft64_10.dll"; #else static constexpr char* win_curand_lib = "curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR @@ -122,7 +125,9 @@ static constexpr char* win_cusolver_lib = static constexpr char* win_cusparse_lib = "cusparse64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR ".dll;cusparse64_" CUDA_VERSION_MAJOR ".dll"; -static constexpr char* win_cufft_lib = "cufft64_" CUDA_MAJOR_VERSION ".dll"; +static constexpr char* win_cufft_lib = + "cufft64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR + ".dll;cufft64_" CUDA_VERSION_MAJOR ".dll"; #endif // CUDA_VERSION #endif diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index d563b3bed52046..1b4bcb709331cd 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -534,8 +534,8 @@ def test_hfft2(self): [('test_n_nagative', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2), (-2, -1), 'backward', ValueError), \ - ('test_n_equal_input_length', - np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (2, 1), (-2, -1), 
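The win_cufft_lib strings in patch 6 above follow the pattern already used for curand/cusolver/cusparse: a semicolon-separated candidate list ("cufft64_<major><minor>.dll;cufft64_<major>.dll;...") tried in order, because the DLL's version suffix varies across CUDA releases. Assuming a loader that splits these lists on ';' (which is how the multi-name strings are evidently consumed on Windows), the fallback logic reduces to:

#include <windows.h>

#include <sstream>
#include <string>

// Try each candidate DLL name in order; return the first that loads.
// A sketch of the fallback idea, not Paddle's actual loader.
HMODULE load_first_candidate(const std::string& candidates) {
  std::istringstream names(candidates);
  std::string name;
  while (std::getline(names, name, ';')) {
    if (HMODULE h = LoadLibraryA(name.c_str())) return h;
  }
  return nullptr;
}

// e.g. load_first_candidate("cufft64_102.dll;cufft64_10.dll");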
+ ('test_zero_point', + np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None, (-2, -1), "backward", ValueError), \ ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (0, 0), (-2, -1), 'backward', ValueError), \ From 14fa65975c1be1ebbaeb5801ef14f4d567b738ca Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 16 Sep 2021 22:08:59 +0800 Subject: [PATCH 7/7] add NOMINMAX to compile on windows --- paddle/fluid/operators/spectral_op.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/spectral_op.h b/paddle/fluid/operators/spectral_op.h index 9e32b5d8ff8efe..e549c4a454b198 100644 --- a/paddle/fluid/operators/spectral_op.h +++ b/paddle/fluid/operators/spectral_op.h @@ -10,6 +10,7 @@ limitations under the License. */ #pragma once +#define NOMINMAX // to use std::min std::max correctly on windows #include #include #include
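For context on the final one-liner: windows.h defines min and max as function-style macros, so any later std::min/std::max call whose argument list the preprocessor can see is expanded into invalid code. Defining NOMINMAX before windows.h is pulled in — directly or transitively, as happens in this header on Windows builds — suppresses those macros. A minimal illustration of the failure mode the define guards against:

// Without NOMINMAX on Windows, std::max(0, ...) below fails to compile:
// the preprocessor rewrites max(a, b) via windows.h's macro first.
#define NOMINMAX  // must precede any (transitive) include of windows.h
#include <windows.h>

#include <algorithm>

int clamp_to_limit(int value, int limit) {
  return std::max(0, std::min(value, limit));  // safe with NOMINMAX
}

An alternative spot-fix is to parenthesize the calls, (std::max)(a, b), which blocks macro expansion, but the project-wide define is the cleaner cure.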