Commit 5ccdf98

Author: Feiyu Chan
use dynload for cufft (#46)
* use std::ptrdiff_t as the datatype of strides (instead of int64_t) to avoid argument mismatches on some platforms
* add complex support for fill_zeros_like
* use dynload for cufft
1 parent 2cb21c0 commit 5ccdf98
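
The "dynload" in the title refers to resolving cuFFT symbols at runtime instead of linking `libcufft.so` into the binary at build time. A minimal standalone sketch of that general pattern, using POSIX `dlopen`/`dlsym` directly; all names here (`GetSym`, `MyCufftCreate`, `cufftCreate_t`) are illustrative, not Paddle's actual API:

```cpp
#include <dlfcn.h>  // dlopen, dlsym (POSIX; link with -ldl)
#include <mutex>
#include <stdexcept>

// Hypothetical stand-in for a cuFFT entry point; the real signature
// comes from <cufft.h>.
using cufftCreate_t = int (*)(void** handle);

static std::once_flag g_flag;
static void* g_handle = nullptr;  // handle to the shared library

// Open the shared library once, on first use, and look up a symbol.
static void* GetSym(const char* name) {
  std::call_once(g_flag, [] { g_handle = dlopen("libcufft.so", RTLD_LAZY); });
  if (g_handle == nullptr) throw std::runtime_error("cannot load libcufft.so");
  return dlsym(g_handle, name);
}

int MyCufftCreate(void** h) {
  // Resolve the symbol once and cache the pointer for later calls.
  static auto fn = reinterpret_cast<cufftCreate_t>(GetSym("cufftCreate"));
  return fn(h);
}
```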

File tree

11 files changed (+270 −62 lines)


paddle/fluid/operators/CMakeLists.txt

Lines changed: 24 additions & 23 deletions
```diff
@@ -98,29 +98,30 @@ else()
   op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
 
-op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS})
-if (WITH_GPU)
-    find_library(CUFFT_LIB libcufft.so
-      PATHS
-      ${CUDA_TOOLKIT_ROOT_DIR}/lib64/
-      NO_DEFAULT_PATH
-    )
-    target_link_libraries(spectral_op ${CUFFT_LIB})
-endif()
-if(WITH_ONEMKL)
-    find_library(ONEMKL_CORE libmkl_core.so
-      PATHS
-      ${MKL_ROOT}/lib/${MKL_ARCH}
-      NO_DEFAULT_PATH
-    )
-    find_library(ONEMKL_THREAD libmkl_intel_thread.so
-      PATHS
-      ${MKL_ROOT}/lib/${MKL_ARCH}
-      NO_DEFAULT_PATH
-    )
-    target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
-    target_link_libraries(spectral_op MKL::mkl_core MKL::mkl_intel_thread)
-endif()
+op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
+# op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS ${OP_HEADER_DEPS})
+# if (WITH_GPU)
+#     find_library(CUFFT_LIB libcufft.so
+#       PATHS
+#       ${CUDA_TOOLKIT_ROOT_DIR}/lib64/
+#       NO_DEFAULT_PATH
+#     )
+#     target_link_libraries(spectral_op ${CUFFT_LIB})
+# endif()
+# if(WITH_ONEMKL)
+#     find_library(ONEMKL_CORE libmkl_core.so
+#       PATHS
+#       ${MKL_ROOT}/lib/${MKL_ARCH}
+#       NO_DEFAULT_PATH
+#     )
+#     find_library(ONEMKL_THREAD libmkl_intel_thread.so
+#       PATHS
+#       ${MKL_ROOT}/lib/${MKL_ARCH}
+#       NO_DEFAULT_PATH
+#     )
+#     target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
+#     target_link_libraries(spectral_op MKL::mkl_core MKL::mkl_intel_thread)
+# endif()
 
 op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
 op_library(eye_op DEPS ${OP_HEADER_DEPS})
```
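
The key change in this file is swapping the hard `find_library` link against `libcufft.so` for a dependency on `dynload_cuda`: cuFFT symbols are then resolved at runtime on first use, so the operator library can be built, and loaded, on machines without cuFFT installed. The old `find_library`/ONEMKL branches are kept commented out rather than deleted.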

paddle/fluid/operators/fill_zeros_like_op.cc

Lines changed: 11 additions & 2 deletions
```diff
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/fill_zeros_like_op.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
 namespace operators {
@@ -93,12 +94,20 @@ REGISTER_OP_CPU_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     fill_zeros_like2,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex<double>>);
```
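
No kernel changes are needed for the new complex registrations, because a fill-with-zeros kernel is type-generic: it only requires the element type to be value-initializable. A self-contained sketch of that property (`std::complex` stands in for `paddle::platform::complex`, and `FillZerosLike` here is illustrative, not Paddle's API):

```cpp
#include <cassert>
#include <complex>
#include <vector>

// Fill dst with value-initialized elements, shaped like src.
// T{} is zero for arithmetic types and (0, 0) for complex types,
// so the same template covers both without specialization.
template <typename T>
void FillZerosLike(const std::vector<T>& src, std::vector<T>& dst) {
  dst.assign(src.size(), T{});
}

int main() {
  std::vector<std::complex<float>> x{{1.f, 2.f}, {3.f, 4.f}}, y;
  FillZerosLike(x, y);
  assert(y.size() == 2 && y[0] == std::complex<float>(0.f, 0.f));
}
```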

paddle/fluid/operators/fill_zeros_like_op.cu.cc

Lines changed: 11 additions & 2 deletions
```diff
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/fill_zeros_like_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
@@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
-    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     fill_zeros_like2,
@@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
-    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<float>>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::complex<double>>);
```

paddle/fluid/operators/spectral_op.cc

Lines changed: 31 additions & 17 deletions
```diff
@@ -240,15 +240,29 @@ class FFTC2ROp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fft_c2r");
 
     const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
+    const auto x_dim = ctx->GetInputDim("X");
+    for (size_t i = 0; i < axes.size() - 1L; i++) {
+      const auto fft_n_point = (x_dim[axes[i]] - 1) * 2;
+      PADDLE_ENFORCE_GT(fft_n_point, 0,
+                        platform::errors::InvalidArgument(
+                            "Invalid fft n-point (%d).", fft_n_point));
+    }
 
     const int64_t last_dim_size = ctx->Attrs().Get<int64_t>("last_dim_size");
     framework::DDim out_dim(ctx->GetInputDim("X"));
     const int64_t last_fft_axis = axes.back();
     if (last_dim_size == 0) {
       const int64_t last_fft_dim_size = out_dim.at(last_fft_axis);
-      out_dim.at(last_fft_axis) = (last_fft_dim_size - 1) * 2;
+      const int64_t fft_n_point = (last_fft_dim_size - 1) * 2;
+      PADDLE_ENFORCE_GT(fft_n_point, 0,
+                        platform::errors::InvalidArgument(
+                            "Invalid fft n-point (%d).", fft_n_point));
+      out_dim.at(last_fft_axis) = fft_n_point;
     } else {
-      out_dim.at(last_fft_axis) = ctx->Attrs().Get<int64_t>("last_dim_size");
+      PADDLE_ENFORCE_GT(last_dim_size, 0,
+                        platform::errors::InvalidArgument(
+                            "Invalid fft n-point (%d).", last_dim_size));
+      out_dim.at(last_fft_axis) = last_dim_size;
     }
     ctx->SetOutputDim("Out", out_dim);
   }
```
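
To see what the new checks in this hunk guard: a one-sided R2C transform of an n-point signal yields n/2 + 1 complex values, so the C2R inverse infers n = (n_complex − 1) × 2 when `last_dim_size` is not given. A degenerate axis would infer a non-positive n-point, which the `PADDLE_ENFORCE_GT` calls now reject. A standalone sketch of the rule (the function name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Mirror of the shape-inference branch above: last_dim_size == 0 means
// "infer the signal length from the complex input".
int64_t InferC2RLastDim(int64_t complex_len, int64_t last_dim_size) {
  int64_t n = (last_dim_size == 0) ? (complex_len - 1) * 2 : last_dim_size;
  assert(n > 0 && "invalid fft n-point");  // PADDLE_ENFORCE_GT in the op
  return n;
}

int main() {
  // 5 complex values -> an inferred 8-point inverse (rfft(8) gives 5 outputs).
  assert(InferC2RLastDim(/*complex_len=*/5, /*last_dim_size=*/0) == 8);
  // An explicit odd length (rfft(9) also gives 5 outputs) is accepted as-is.
  assert(InferC2RLastDim(/*complex_len=*/5, /*last_dim_size=*/9) == 9);
  // complex_len == 1 with no explicit length would infer n == 0: rejected.
}
```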
```diff
@@ -681,11 +695,11 @@ struct FFTC2CFunctor<platform::CPUDeviceContext, Ti, To> {
     const auto& input_dim = x->dims();
     const std::vector<size_t> in_sizes =
         framework::vectorize<size_t>(input_dim);
-    std::vector<int64_t> in_strides =
-        framework::vectorize<int64_t>(framework::stride(input_dim));
+    std::vector<std::ptrdiff_t> in_strides =
+        framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
     const int64_t data_size = sizeof(C);
     std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
-                   [](int64_t s) { return s * data_size; });
+                   [](std::ptrdiff_t s) { return s * data_size; });
 
     const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
     auto* out_data = reinterpret_cast<C*>(out->data<To>());
@@ -714,24 +728,24 @@ struct FFTR2CFunctor<platform::CPUDeviceContext, Ti, To> {
     const auto& input_dim = x->dims();
     const std::vector<size_t> in_sizes =
         framework::vectorize<size_t>(input_dim);
-    std::vector<int64_t> in_strides =
-        framework::vectorize<int64_t>(framework::stride(input_dim));
+    std::vector<std::ptrdiff_t> in_strides =
+        framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
     {
       const int64_t data_size = sizeof(R);
       std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
-                     [](int64_t s) { return s * data_size; });
+                     [](std::ptrdiff_t s) { return s * data_size; });
     }
 
     const auto& output_dim = out->dims();
     const std::vector<size_t> out_sizes =
         framework::vectorize<size_t>(output_dim);
-    std::vector<int64_t> out_strides =
-        framework::vectorize<int64_t>(framework::stride(output_dim));
+    std::vector<std::ptrdiff_t> out_strides =
+        framework::vectorize<std::ptrdiff_t>(framework::stride(output_dim));
     {
       const int64_t data_size = sizeof(C);
       std::transform(out_strides.begin(), out_strides.end(),
                      out_strides.begin(),
-                     [](int64_t s) { return s * data_size; });
+                     [](std::ptrdiff_t s) { return s * data_size; });
     }
 
     const auto* in_data = x->data<R>();
@@ -761,24 +775,24 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
     const auto& input_dim = x->dims();
     const std::vector<size_t> in_sizes =
         framework::vectorize<size_t>(input_dim);
-    std::vector<int64_t> in_strides =
-        framework::vectorize<int64_t>(framework::stride(input_dim));
+    std::vector<std::ptrdiff_t> in_strides =
+        framework::vectorize<std::ptrdiff_t>(framework::stride(input_dim));
     {
       const int64_t data_size = sizeof(C);
       std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(),
-                     [](int64_t s) { return s * data_size; });
+                     [](std::ptrdiff_t s) { return s * data_size; });
     }
 
     const auto& output_dim = out->dims();
     const std::vector<size_t> out_sizes =
         framework::vectorize<size_t>(output_dim);
-    std::vector<int64_t> out_strides =
-        framework::vectorize<int64_t>(framework::stride(output_dim));
+    std::vector<std::ptrdiff_t> out_strides =
+        framework::vectorize<std::ptrdiff_t>(framework::stride(output_dim));
     {
       const int64_t data_size = sizeof(R);
       std::transform(out_strides.begin(), out_strides.end(),
                      out_strides.begin(),
-                     [](int64_t s) { return s * data_size; });
+                     [](std::ptrdiff_t s) { return s * data_size; });
     }
 
     const auto* in_data = reinterpret_cast<const C*>(x->data<Ti>());
```
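
The int64_t → std::ptrdiff_t switch addresses the "argument mismatch on some platforms" from the commit message: both types are 64 bits wide on common 64-bit targets, but they can be distinct types (e.g. `long long` vs. `long`), so a `std::vector<int64_t>` will not bind to a parameter declared as `std::vector<std::ptrdiff_t>`. A minimal illustration, with a hypothetical pocketfft-style backend signature:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical backend entry point; pocketfft-style CPU FFT APIs take
// strides as std::ptrdiff_t.
void fft_backend(const std::vector<std::ptrdiff_t>& strides) { (void)strides; }

int main() {
  std::vector<int64_t> s64{16, 8};
  // On platforms where int64_t is `long long` while std::ptrdiff_t is
  // `long`, std::vector<int64_t> and std::vector<std::ptrdiff_t> are
  // unrelated types, and the next call would fail to compile:
  // fft_backend(s64);  // error: no matching function for call
  std::vector<std::ptrdiff_t> sp(s64.begin(), s64.end());
  fft_backend(sp);  // OK: build the strides as std::ptrdiff_t from the start
}
```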

paddle/fluid/operators/spectral_op.cu

Lines changed: 17 additions & 8 deletions
```diff
@@ -26,6 +26,7 @@
 #include "paddle/fluid/operators/conj_op.h"
 #include "paddle/fluid/operators/spectral_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
+#include "paddle/fluid/platform/dynload/cufft.h"
 
 namespace paddle {
 namespace operators {
@@ -141,15 +142,16 @@ class CuFFTHandle {
   ::cufftHandle handle_;
 
  public:
-  CuFFTHandle() { CUFFT_CHECK(cufftCreate(&handle_)); }
+  CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); }
 
   ::cufftHandle& get() { return handle_; }
   const ::cufftHandle& get() const { return handle_; }
 
   ~CuFFTHandle() {
 // Not using fftDestroy() for rocFFT to work around double freeing of handles
 #ifndef __HIPCC__
-    cufftDestroy(handle_);
+    std::cout << "Dtor of CuFFTHandle" << std::endl;
+    CUFFT_CHECK(platform::dynload::cufftDestroy(handle_));
 #endif
   }
 };
@@ -245,7 +247,8 @@ class CuFFTConfig {
 #endif
 
     // disable auto allocation of workspace to use THC allocator
-    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));
+    CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation(
+        plan(), /* autoAllocate */ 0));
 
     size_t ws_size_t;
 
@@ -258,7 +261,7 @@
         batch, &ws_size_t));
 #else
 
-    CUFFT_CHECK(cufftXtMakePlanMany(
+    CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany(
         plan(), signal_ndim, signal_sizes.data(),
         /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
         /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
@@ -364,6 +367,11 @@ class PlanLRUCache {
     return *this;
   }
 
+  ~PlanLRUCache() {
+    std::cout << "DTor of PlanLRUCache" << std::endl;
+    clear();
+  }
+
   // If key is in this cache, return the cached config. Otherwise, emplace the
   // config in this cache and return it.
   CuFFTConfig& lookup(PlanKey params) {
@@ -498,8 +506,8 @@ static void exec_cufft_plan(const CuFFTConfig& config, void* in_data,
     PADDLE_THROW(platform::errors::InvalidArgument(
         "hipFFT only support transforms of type float32 and float64"));
 #else
-    CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
-                            forward ? CUFFT_FORWARD : CUFFT_INVERSE));
+    CUFFT_CHECK(platform::dynload::cufftXtExec(
+        plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
 #endif
 }
 
@@ -641,10 +649,11 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
   auto& plan = config->plan();
 
   // prepare cufft for execution
-  CUFFT_CHECK(cufftSetStream(plan, ctx.stream()));
+  CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream()));
   framework::Tensor workspace_tensor;
   workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
-  CUFFT_CHECK(cufftSetWorkArea(plan, workspace_tensor.data<To>()));
+  CUFFT_CHECK(
+      platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data<To>()));
 
   // execute transform plan
   if (fft_type == FFTTransformType::C2R && forward) {
```
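
`CUFFT_CHECK` is defined elsewhere in Paddle's sources; a hedged sketch of what such a status-checking macro typically looks like (the real definition may differ):

```cpp
#include <cufft.h>  // cufftResult, CUFFT_SUCCESS
#include <sstream>
#include <stdexcept>

// Sketch only: evaluate a cuFFT call once and turn a non-success
// status into an exception carrying the call site.
#define CUFFT_CHECK_SKETCH(expr)                                  \
  do {                                                            \
    cufftResult status = (expr);                                  \
    if (status != CUFFT_SUCCESS) {                                \
      std::ostringstream oss;                                     \
      oss << "cuFFT error " << static_cast<int>(status) << " at " \
          << __FILE__ << ":" << __LINE__;                         \
      throw std::runtime_error(oss.str());                        \
    }                                                             \
  } while (0)
```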

paddle/fluid/platform/dynload/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
 
-list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc)
+list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc cufft.cc)
 
 if (NOT WITH_NV_JETSON)
   list(APPEND CUDA_SRCS nvjpeg.cc)
```

paddle/fluid/platform/dynload/cufft.cc

Lines changed: 44 additions & 0 deletions

```diff
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/cufft.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+std::once_flag cufft_dso_flag;
+void* cufft_dso_handle = nullptr;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+CUFFT_FFT_ROUTINE_EACH(DEFINE_WRAP);
+
+bool HasCUFFT() {
+  std::call_once(cufft_dso_flag,
+                 []() { cufft_dso_handle = GetCUFFTDsoHandle(); });
+  return cufft_dso_handle != nullptr;
+}
+
+void EnforceCUFFTLoaded(const char* fn_name) {
+  PADDLE_ENFORCE_NOT_NULL(
+      cufft_dso_handle,
+      platform::errors::PreconditionNotMet(
+          "Cannot load cufft shared library. Cannot invoke method %s.",
+          fn_name));
+}
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
```
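
`DEFINE_WRAP` instantiates wrapper objects that the companion header `cufft.h` declares. Following the pattern of Paddle's other dynload headers (cuBLAS, cuDNN), the declaration side likely looks roughly like this; it is a sketch, not the verbatim header:

```cpp
#include <dlfcn.h>  // dlsym
#include <mutex>

namespace paddle {
namespace platform {
namespace dynload {

extern std::once_flag cufft_dso_flag;
extern void* cufft_dso_handle;
void* GetCUFFTDsoHandle();                    // from the dynamic loader
void EnforceCUFFTLoaded(const char* fn_name); // defined in cufft.cc above

// Each wrapped cuFFT routine gets a functor that loads the DSO once,
// resolves the symbol lazily, caches it, and forwards the arguments.
#define DECLARE_DYNAMIC_LOAD_CUFFT_WRAP(__name)                   \
  struct DynLoad__##__name {                                      \
    template <typename... Args>                                   \
    auto operator()(Args... args) {                               \
      using cufft_func = decltype(&::__name);                     \
      std::call_once(cufft_dso_flag, []() {                       \
        cufft_dso_handle = GetCUFFTDsoHandle();                   \
      });                                                         \
      EnforceCUFFTLoaded(#__name);                                \
      static void* p_##__name = dlsym(cufft_dso_handle, #__name); \
      return reinterpret_cast<cufft_func>(p_##__name)(args...);   \
    }                                                             \
  };                                                              \
  extern DynLoad__##__name __name

// e.g. CUFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUFFT_WRAP);

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
```

With this in place, call sites such as `platform::dynload::cufftCreate(&handle_)` in spectral_op.cu compile against the wrapper object rather than the raw cuFFT symbol, which is why no `libcufft.so` is needed at link time.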
