Skip to content

Commit 11b9f5f

Browse files
cxxlyzhiboniuchenfeiyulijiaqi0612
authored
[Cherry-pick]FFT function enhancements and bugfixes (PaddlePaddle#36537)
* update fft api path (PaddlePaddle#36219) * update fft api path * add sample code for ihfft2 Co-authored-by: chenfeiyu <[email protected]> * fix fft axis (PaddlePaddle#36321) fix: `-1` is used when fft's axis is `0` * use unified external error message for cufft api (PaddlePaddle#36114) * fft: modify sample code result (PaddlePaddle#36325) * dynamic load mkl as a fft backend when it is available and requested (PaddlePaddle#36414) * add rocm support for fft api (PaddlePaddle#36415) * move signal apis * move fft and signal API path (#2) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos in signal.py (#3) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * disable Cache when CUFFT_VERSION >= 10200 (#4) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * Add LRUCache for fft plans * add LRUCache for cufft and hipfft (#5) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * WIP: add cache * delete move constructor and operator= for CuFFTHandle and FFTConfig * remove log from CuFFTHandle and FFTConfig * add lrucache for fft rocm backend * disable LRUCache when CUFFT_VERSION >= 10200 * disable copy and move for hipFFTHandle; format code Co-authored-by: Xiaoxu Chen <[email protected]> * remove debug message of cufftHandler * roll_op: support Tensor as input for shifts (PaddlePaddle#36727) * fix fftshift/ifftshift on static mode * update roll_op version * add more test cases for fftshift/ifftshift Co-authored-by: zhiboniu <[email protected]> Co-authored-by: chenfeiyu <[email protected]> Co-authored-by: LJQ❤️ <[email protected]>
1 parent 96edcea commit 11b9f5f

29 files changed

+1413
-549
lines changed

cmake/third_party.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ if(WITH_GPU)
255255
include(external/cub) # download cub
256256
list(APPEND third_party_deps extern_cub)
257257
endif()
258-
set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz" CACHE STRING "" FORCE)
259-
file_download_and_uncompress(${URL} "externalError" MD5 061f3b7895aadcbe2c3ed592590f8b10) # download file externalErrorMsg.tar.gz
258+
set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE)
259+
file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) # download file externalErrorMsg.tar.gz
260260
if(WITH_TESTING)
261261
# copy externalErrorMsg.pb, just for unittest can get error message correctly.
262262
set(SRC_DIR ${THIRD_PARTY_PATH}/externalError/data)

paddle/fluid/operators/CMakeLists.txt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,20 @@ else()
105105
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
106106
endif()
107107

108-
if (WITH_GPU AND (NOT WITH_ROCM))
109-
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
108+
if (WITH_GPU OR WITH_ROCM)
109+
if (MKL_FOUND AND WITH_ONEMKL)
110+
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
111+
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
112+
else()
113+
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
114+
endif()
110115
else()
111-
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
116+
if (MKL_FOUND AND WITH_ONEMKL)
117+
op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS})
118+
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
119+
else()
120+
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
121+
endif()
112122
endif()
113123

114124
op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)

paddle/fluid/operators/roll_op.cc

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel {
4040
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
4141
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
4242

43-
if (dims.size() != 0) {
44-
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
45-
platform::errors::InvalidArgument(
46-
"When dims.size() != 0, dims.size() "
47-
"should be equal to "
48-
"shifts.size(). But received "
49-
"dims.size() = %d, shifts.size() = %d",
50-
dims.size(), shifts.size()));
51-
} else {
52-
PADDLE_ENFORCE_EQ(shifts.size(), 1,
53-
platform::errors::InvalidArgument(
54-
"When dims.size() == 0, shifts.size() "
55-
"should be equal to 1, But received "
56-
"shifts.size() = %d",
57-
shifts.size()));
43+
if (!ctx->HasInput("ShiftsTensor")) {
44+
if (dims.size() != 0) {
45+
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
46+
platform::errors::InvalidArgument(
47+
"When dims.size() != 0, dims.size() "
48+
"should be equal to "
49+
"shifts.size(). But received "
50+
"dims.size() = %d, shifts.size() = %d",
51+
dims.size(), shifts.size()));
52+
} else {
53+
PADDLE_ENFORCE_EQ(shifts.size(), 1,
54+
platform::errors::InvalidArgument(
55+
"When dims.size() == 0, shifts.size() "
56+
"should be equal to 1, But received "
57+
"shifts.size() = %d",
58+
shifts.size()));
59+
}
5860
}
5961

6062
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
105107
"The number of places by which the elements "
106108
"of the tensor are shifted.")
107109
.SetDefault({});
110+
AddInput("ShiftsTensor",
111+
"The number of places by which the elements of the tensor "
112+
"are shifted.")
113+
.AsDispensable();
108114
AddAttr<std::vector<int64_t>>(
109115
"axis",
110116
"Axis along which to roll. It must have the same size "
@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
129135
void Apply(GradOpPtr<T> op) const override {
130136
op->SetType("roll_grad");
131137
op->SetInput("X", this->Input("X"));
138+
if (this->HasInput("ShiftsTensor")) {
139+
op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
140+
}
132141
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
133142
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
134143
op->SetAttrMap(this->Attrs());
@@ -174,7 +183,12 @@ REGISTER_OP_VERSION(roll)
174183
"(std::vector<int64_t>) Axis along which to roll. "
175184
"It must have the same size with shifts, or size = 0.",
176185
std::vector<int64_t>())
177-
.DeleteAttr(
178-
"dims",
179-
"(std::vector<int64_t>) Dims along which to roll. "
180-
"It must have the same size with shifts, or size = 0."));
186+
.DeleteAttr("dims",
187+
"(std::vector<int64_t>) Dims along which to roll. "
188+
"It must have the same size with shifts, or size = 0."))
189+
.AddCheckpoint(
190+
R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC",
191+
paddle::framework::compatible::OpVersionDesc().NewInput(
192+
"ShiftsTensor",
193+
"The number of places by which the elements of"
194+
"the tensor are shifted."));

paddle/fluid/operators/roll_op.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
5959
auto* in = context.Input<LoDTensor>("X");
6060
auto* out = context.Output<LoDTensor>("Out");
6161
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
62+
if (context.HasInput("ShiftsTensor")) {
63+
const auto* shifts_tensor =
64+
context.Input<framework::Tensor>("ShiftsTensor");
65+
PADDLE_ENFORCE_EQ(
66+
shifts_tensor->dims().size(), 1,
67+
platform::errors::InvalidArgument(
68+
"The rank of ShiftsTensor is expected to be 1, got %s",
69+
shifts_tensor->dims().size()));
70+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
71+
}
6272
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
6373

6474
auto* in_data = in->data<T>();
@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
134144
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
135145
auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
136146
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
147+
if (context.HasInput("ShiftsTensor")) {
148+
const auto* shifts_tensor =
149+
context.Input<framework::Tensor>("ShiftsTensor");
150+
PADDLE_ENFORCE_EQ(
151+
shifts_tensor->dims().size(), 1,
152+
platform::errors::InvalidArgument(
153+
"The rank of ShiftsTensor is expected to be 1, got %s",
154+
shifts_tensor->dims().size()));
155+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
156+
}
137157
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
138158

139159
auto* in_data = in->data<T>();

paddle/fluid/operators/roll_op.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <memory>
1717
#include <vector>
1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/utils.h"
20+
#include "paddle/fluid/platform/enforce.h"
1921

2022
namespace paddle {
2123
namespace operators {
@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
8587
auto& input = input_var->Get<LoDTensor>();
8688
auto* output = output_var->GetMutable<LoDTensor>();
8789
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
90+
if (context.HasInput("ShiftsTensor")) {
91+
const auto* shifts_tensor =
92+
context.Input<framework::Tensor>("ShiftsTensor");
93+
PADDLE_ENFORCE_EQ(
94+
shifts_tensor->dims().size(), 1,
95+
platform::errors::InvalidArgument(
96+
"The rank of ShiftsTensor is expected to be 1, got %s",
97+
shifts_tensor->dims().size()));
98+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
99+
}
88100
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
89101

90102
std::vector<T> out_vec;
@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
123135
auto& input = input_var->Get<LoDTensor>();
124136
auto* output = output_var->GetMutable<LoDTensor>();
125137
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
138+
if (context.HasInput("ShiftsTensor")) {
139+
const auto* shifts_tensor =
140+
context.Input<framework::Tensor>("ShiftsTensor");
141+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
142+
}
126143
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
127144

128145
std::vector<T> out_vec;

0 commit comments

Comments
 (0)