
Commit 1d5746b

add rocm support for fft api (#36415)
1 parent 77f4597 commit 1d5746b

File tree

10 files changed: +679 -380 lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
@@ -102,8 +102,7 @@ else()
     op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
 
-
-if (WITH_GPU AND (NOT WITH_ROCM))
+if (WITH_GPU OR WITH_ROCM)
   if (MKL_FOUND AND WITH_ONEMKL)
     op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
     target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
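
Relaxing the guard from WITH_GPU AND (NOT WITH_ROCM) to WITH_GPU OR WITH_ROCM means spectral_op is now built for ROCm toolchains as well; which FFT library each translation unit binds to is then decided at compile time by the platform macros, using the same guard pattern the new helper header below opens with. A minimal sketch of that dispatch, mirroring the includes added in this commit:

#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hipfft.h"  // ROCm builds bind hipFFT
#endif

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cufft.h"   // CUDA builds bind cuFFT
#endif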
Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/spectral_op.h"

#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hipfft.h"
#endif

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cufft.h"
#endif

namespace paddle {
namespace operators {
using ScalarType = framework::proto::VarType::Type;
const int64_t kMaxCUFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1;

// This struct is used to easily compute hashes of the FFT parameters.
// It will be the **key** to the plan cache.
struct PlanKey {
  // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3
  int64_t signal_ndim_;
  // These sizes include the additional batch dimension as well.
  int64_t sizes_[kMaxDataNdim];
  int64_t input_shape_[kMaxDataNdim];
  int64_t output_shape_[kMaxDataNdim];
  FFTTransformType fft_type_;
  ScalarType value_type_;

  PlanKey() = default;

  PlanKey(const std::vector<int64_t>& in_shape,
          const std::vector<int64_t>& out_shape,
          const std::vector<int64_t>& signal_size, FFTTransformType fft_type,
          ScalarType value_type) {
    // Padding bits must be zeroed for hashing
    memset(this, 0, sizeof(*this));
    signal_ndim_ = signal_size.size() - 1;
    fft_type_ = fft_type;
    value_type_ = value_type;

    std::copy(signal_size.cbegin(), signal_size.cend(), sizes_);
    std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_);
    std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_);
  }
};
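The memset in the constructor is what makes bytewise hashing safe: it zeroes the padding bytes, so two keys built from equal parameters are bit-identical. The plan cache itself is not part of this header; as a minimal sketch under that assumption (PlanKeyHash and PlanKeyEqual are hypothetical helpers, not in this commit), the key could be hashed and compared like this:

// Hypothetical helpers (not in this commit). Bytewise hashing/equality is
// valid only because the PlanKey constructor memsets its padding to zero.
struct PlanKeyHash {
  size_t operator()(const PlanKey& key) const {
    const auto* bytes = reinterpret_cast<const unsigned char*>(&key);
    size_t hash = 14695981039346656037ULL;  // 64-bit FNV-1a offset basis
    for (size_t i = 0; i < sizeof(PlanKey); ++i) {
      hash = (hash ^ bytes[i]) * 1099511628211ULL;  // FNV-1a prime
    }
    return hash;
  }
};

struct PlanKeyEqual {
  bool operator()(const PlanKey& lhs, const PlanKey& rhs) const {
    return std::memcmp(&lhs, &rhs, sizeof(PlanKey)) == 0;
  }
};

// e.g. std::unordered_map<PlanKey, CuFFTConfig, PlanKeyHash, PlanKeyEqual>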
#if defined(PADDLE_WITH_CUDA)
// An RAII encapsulation of cuFFTHandle
class CuFFTHandle {
  ::cufftHandle handle_;

 public:
  CuFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_));
  }

  ::cufftHandle& get() { return handle_; }
  const ::cufftHandle& get() const { return handle_; }

  ~CuFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftDestroy(handle_));
  }
};

using plan_size_type = long long int;  // NOLINT
// This class contains all the information needed to execute a cuFFT plan:
//   1. the plan
//   2. the workspace size needed
class CuFFTConfig {
 public:
  // Move semantics is enough for this class. Although the plan is already
  // wrapped in an RAII handle, still remove the copy constructor and
  // assignment operator so we don't accidentally copy and take a perf hit.
  explicit CuFFTConfig(const PlanKey& plan_key)
      : CuFFTConfig(
            std::vector<int64_t>(plan_key.sizes_,
                                 plan_key.sizes_ + plan_key.signal_ndim_ + 1),
            plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}

  // sizes are full signal, including batch size and always two-sided
  CuFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
              FFTTransformType fft_type, ScalarType dtype)
      : fft_type_(fft_type), value_type_(dtype) {
    // signal sizes (excluding batch dim)
    std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const auto batch = static_cast<plan_size_type>(sizes[0]);
    PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
                      platform::errors::InvalidArgument(
                          "The signal_ndim must be equal to sizes.size() - 1, "
                          "but signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
                          signal_ndim, sizes.size() - 1));

    cudaDataType itype, otype, exec_type;
    const auto complex_input = has_complex_input(fft_type);
    const auto complex_output = has_complex_output(fft_type);
    if (dtype == framework::proto::VarType::FP32) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == framework::proto::VarType::FP64) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == framework::proto::VarType::FP16) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "cuFFT only supports transforms of type float16, float32 and "
          "float64"));
    }

    // disable auto allocation of workspace to use the allocator from the
    // framework
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetAutoAllocation(
        plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtMakePlanMany(
        plan(), signal_ndim, signal_sizes.data(),
        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
        batch, &ws_size_t, exec_type));

    ws_size = ws_size_t;
  }

  const cufftHandle& plan() const { return plan_ptr.get(); }

  FFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  size_t workspace_size() const { return ws_size; }

 private:
  CuFFTHandle plan_ptr;
  size_t ws_size;
  FFTTransformType fft_type_;
  ScalarType value_type_;
};

#elif defined(PADDLE_WITH_HIP)
// An RAII encapsulation of hipfftHandle
class HIPFFTHandle {
  ::hipfftHandle handle_;

 public:
  HIPFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_));
  }

  ::hipfftHandle& get() { return handle_; }
  const ::hipfftHandle& get() const { return handle_; }

  ~HIPFFTHandle() {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftDestroy(handle_));
  }
};

using plan_size_type = int;
// This class contains all the information needed to execute a hipFFT plan:
//   1. the plan
//   2. the workspace size needed
class HIPFFTConfig {
 public:
  // Move semantics is enough for this class. Although the plan is already
  // wrapped in an RAII handle, still remove the copy constructor and
  // assignment operator so we don't accidentally copy and take a perf hit.
  explicit HIPFFTConfig(const PlanKey& plan_key)
      : HIPFFTConfig(
            std::vector<int64_t>(plan_key.sizes_,
                                 plan_key.sizes_ + plan_key.signal_ndim_ + 1),
            plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}

  // sizes are full signal, including batch size and always two-sided
  HIPFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
               FFTTransformType fft_type, ScalarType dtype)
      : fft_type_(fft_type), value_type_(dtype) {
    // signal sizes (excluding batch dim)
    std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const auto batch = static_cast<plan_size_type>(sizes[0]);
    PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1,
                      platform::errors::InvalidArgument(
                          "The signal_ndim must be equal to sizes.size() - 1, "
                          "but signal_ndim is: [%d], sizes.size() - 1 is: [%d]",
                          signal_ndim, sizes.size() - 1));

    hipfftType exec_type = [&] {
      if (dtype == framework::proto::VarType::FP32) {
        switch (fft_type) {
          case FFTTransformType::C2C:
            return HIPFFT_C2C;
          case FFTTransformType::R2C:
            return HIPFFT_R2C;
          case FFTTransformType::C2R:
            return HIPFFT_C2R;
        }
      } else if (dtype == framework::proto::VarType::FP64) {
        switch (fft_type) {
          case FFTTransformType::C2C:
            return HIPFFT_Z2Z;
          case FFTTransformType::R2C:
            return HIPFFT_D2Z;
          case FFTTransformType::C2R:
            return HIPFFT_Z2D;
        }
      }
      PADDLE_THROW(platform::errors::InvalidArgument(
          "hipFFT only supports transforms of type float32 and float64"));
    }();

    // disable auto allocation of workspace to use the allocator from the
    // framework
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetAutoAllocation(
        plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftMakePlanMany(
        plan(), signal_ndim, signal_sizes.data(),
        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type,
        batch, &ws_size_t));

    ws_size = ws_size_t;
  }

  const hipfftHandle& plan() const { return plan_ptr.get(); }

  FFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  size_t workspace_size() const { return ws_size; }

 private:
  HIPFFTHandle plan_ptr;
  size_t ws_size;
  FFTTransformType fft_type_;
  ScalarType value_type_;
};
#endif
}  // namespace operators
}  // namespace paddle
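
Both configs disable workspace auto-allocation (cufftSetAutoAllocation / hipfftSetAutoAllocation with autoAllocate = 0), so whoever executes the plan must first attach a buffer of workspace_size() bytes. A minimal caller sketch, assuming the dynload wrappers expose the SetWorkArea/Exec entry points; the kernels in this commit obtain the workspace from the framework allocator rather than the raw cudaMalloc/hipMalloc used here for brevity:

// Hypothetical caller (not part of this commit): run a single-precision C2C
// transform through a freshly built config.
void ExecC2CSketch(const PlanKey& key, void* in, void* out, bool forward) {
#if defined(PADDLE_WITH_CUDA)
  CuFFTConfig config(key);
  void* workspace = nullptr;
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&workspace, config.workspace_size()));
  // hand the caller-owned workspace to the plan, then execute
  PADDLE_ENFORCE_CUDA_SUCCESS(
      platform::dynload::cufftSetWorkArea(config.plan(), workspace));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtExec(
      config.plan(), in, out, forward ? CUFFT_FORWARD : CUFFT_INVERSE));
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaFree(workspace));
#elif defined(PADDLE_WITH_HIP)
  HIPFFTConfig config(key);
  void* workspace = nullptr;
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&workspace, config.workspace_size()));
  PADDLE_ENFORCE_CUDA_SUCCESS(
      platform::dynload::hipfftSetWorkArea(config.plan(), workspace));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2C(
      config.plan(), static_cast<hipfftComplex*>(in),
      static_cast<hipfftComplex*>(out),
      forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
  PADDLE_ENFORCE_CUDA_SUCCESS(hipFree(workspace));
#endif
}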
