
Commit dbc08d6

modify complex template for elementwise ops (#33071)
* modify complex template for elementwise ops
* modify mul, div grad struct
* add complex template for the CudaShuffleDownSync and CudaShuffleXorSync funcs, and fix the bug introduced when the CUDA < 9000 code path was removed
* fix shuffle func args bug
Parent: 3a7b9ed

File tree: 11 files changed (+180, -193 lines)
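
Across the files below, the pattern is the same: the two concrete types paddle::platform::complex64 and paddle::platform::complex128 are replaced by one class template paddle::platform::complex<T>, instantiated as complex<float> and complex<double>. The diffs rely on this template having public real/imag members, a (real, imag) constructor, and arithmetic operators. A minimal illustrative sketch of that interface follows; it is a stand-in only, the actual definition lives in paddle/fluid/platform/complex.h and adds HOSTDEVICE qualifiers, conversions, and a full operator set.

namespace paddle {
namespace platform {

// Illustrative stand-in for the unified complex template.
template <typename T>
struct complex {
  T real;
  T imag;
  complex(T r = T(0), T i = T(0)) : real(r), imag(i) {}
};

template <typename T>
complex<T> operator-(const complex<T>& a) {
  return complex<T>(-a.real, -a.imag);
}

template <typename T>
complex<T> operator*(const complex<T>& a, const complex<T>& b) {
  // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
  return complex<T>(a.real * b.real - a.imag * b.imag,
                    a.real * b.imag + a.imag * b.real);
}

template <typename T>
complex<T> operator/(const complex<T>& a, const complex<T>& b) {
  // Multiply by conj(b) / |b|^2.
  T denom = b.real * b.real + b.imag * b.imag;
  return complex<T>((a.real * b.real + a.imag * b.imag) / denom,
                    (a.imag * b.real - a.real * b.imag) / denom);
}

}  // namespace platform
}  // namespace paddle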

paddle/fluid/operators/elementwise/elementwise_add_op.cc

Lines changed: 10 additions & 10 deletions
@@ -20,8 +20,8 @@ limitations under the License. */
 
 namespace paddle {
 namespace platform {
-struct complex128;
-struct complex64;
+template <typename T>
+struct complex;
 }  // namespace platform
 }  // namespace paddle
 
@@ -135,19 +135,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad_grad,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -159,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 
 // A specialization elementwise_add operator, used in gradient accumulation with
 // inplace addto.
@@ -178,9 +178,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 
 REGISTER_OP_VERSION(elementwise_add)
     .AddCheckpoint(
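
The first hunk deserves a note: this .cc file never dereferences the complex types, it only names them in kernel registrations, so rather than including the header it forward-declares them, and that declaration now must take the template form. The pattern in isolation (the c64 alias below is hypothetical, for illustration only):

namespace paddle {
namespace platform {
template <typename T>
struct complex;  // forward declaration; no definition needed to name the type
}  // namespace platform
}  // namespace paddle

// An incomplete class template can still be named, aliased, and pointed to:
using c64 = paddle::platform::complex<float>;  // hypothetical alias
c64* p = nullptr;  // OK: pointers to incomplete types are allowed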

paddle/fluid/operators/elementwise/elementwise_add_op.cu

Lines changed: 11 additions & 10 deletions
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
@@ -141,17 +140,19 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<float>>,
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_add_grad_grad,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -160,15 +161,15 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex64>,
+                                        plat::complex<float>>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex128>);
+                                        plat::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);

paddle/fluid/operators/elementwise/elementwise_div_op.cc

Lines changed: 7 additions & 8 deletions
@@ -17,8 +17,7 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
 namespace operators {
@@ -135,19 +134,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_div_grad_grad,
@@ -160,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 
 REGISTER_OP_VERSION(elementwise_div)
     .AddCheckpoint(

paddle/fluid/operators/elementwise/elementwise_div_op.cu

Lines changed: 31 additions & 27 deletions
@@ -14,8 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
@@ -76,38 +75,43 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
 }
 
 template <>
-__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
-    const paddle::platform::complex64* x, const paddle::platform::complex64* y,
-    const paddle::platform::complex64* out,
-    const paddle::platform::complex64* dout, int64_t size,
-    paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
+__global__ void
+SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
+    const paddle::platform::complex<float>* x,
+    const paddle::platform::complex<float>* y,
+    const paddle::platform::complex<float>* out,
+    const paddle::platform::complex<float>* dout, int64_t size,
+    paddle::platform::complex<float>* dx,
+    paddle::platform::complex<float>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
 
   while (col < size) {
-    paddle::platform::complex64 o = dout[col];
-    paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
-                                               -(out[col] / y[col]).imag);
+    paddle::platform::complex<float> o = dout[col];
+    paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
+                                                    -(out[col] / y[col]).imag);
     dx[col] = o / y_conj;
     dy[col] = -o * out_div_y_conj;
     col += blockDim.x * gridDim.x;
   }
 }
 
 template <>
-__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
-    const paddle::platform::complex128* x,
-    const paddle::platform::complex128* y,
-    const paddle::platform::complex128* out,
-    const paddle::platform::complex128* dout, int64_t size,
-    paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
+__global__ void
+SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
+    const paddle::platform::complex<double>* x,
+    const paddle::platform::complex<double>* y,
+    const paddle::platform::complex<double>* out,
+    const paddle::platform::complex<double>* dout, int64_t size,
+    paddle::platform::complex<double>* dx,
+    paddle::platform::complex<double>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
 
   while (col < size) {
-    paddle::platform::complex128 o = dout[col];
-    paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
-                                                -(out[col] / y[col]).imag);
+    paddle::platform::complex<double> o = dout[col];
+    paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
+                                                     -(out[col] / y[col]).imag);
     dx[col] = o / y_conj;
     dy[col] = -o * out_div_y_conj;
     col += blockDim.x * gridDim.x;
@@ -145,9 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -157,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad_grad,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -173,6 +177,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         int64_t>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
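
Both kernel specializations above encode the same calculus and differ only in element type. For out = x / y with upstream gradient dout, they compute, in the conjugate convention for complex gradients (notation below is mine, not from the source):

\[
dx = \frac{dout}{\overline{y}}, \qquad
dy = -\,dout \cdot \overline{\left(\frac{out}{y}\right)},
\]

which follows from \( \partial out/\partial x = 1/y \) and \( \partial out/\partial y = -x/y^2 = -out/y \), with each factor conjugated. The kernels build those conjugates field-by-field (y_conj and out_div_y_conj) from the .real and .imag members rather than calling a conj helper.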

paddle/fluid/operators/elementwise/elementwise_div_op.h

Lines changed: 14 additions & 34 deletions
@@ -74,23 +74,13 @@ struct DivGradDX {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
 };
 
-template <>
-struct DivGradDX<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 y_conj(y.real, -y.imag);
-    return dout / y_conj;
-  }
-};
-
-template <>
-struct DivGradDX<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 y_conj(y.real, -y.imag);
+template <typename T>
+struct DivGradDX<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> y_conj(y.real, -y.imag);
     return dout / y_conj;
   }
 };
@@ -102,23 +92,13 @@ struct DivGradDY {
   }
 };
 
-template <>
-struct DivGradDY<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
-    return -dout * out_div_y_conj;
-  }
-};
-
-template <>
-struct DivGradDY<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 out_div_y_conj((out / y).real,
+template <typename T>
+struct DivGradDY<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> out_div_y_conj((out / y).real,
                                                 -(out / y).imag);
     return -dout * out_div_y_conj;
   }
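
This header change is where the template pays off: the four full specializations (DivGradDX/DivGradDY, each for complex64 and complex128) collapse into two partial specializations over complex<T>, and the compiler picks the right one for any complex element type. A compilable sketch of the selection mechanism, with the signatures trimmed to the arguments actually used (the real functors take x, y, out, dout and are HOSTDEVICE, and division is written out inline here because this sketch defines no operator/):

#include <cstdio>

template <typename T> struct complex { T real, imag; };

// Primary template, as in the diff: real-typed gradient, dout / y.
template <typename T>
struct DivGradDX {
  T operator()(T y, T dout) const { return dout / y; }
};

// One partial specialization now covers every complex<T>, replacing the
// former pair of full specializations for complex64 and complex128.
template <typename T>
struct DivGradDX<complex<T>> {
  complex<T> operator()(complex<T> y, complex<T> dout) const {
    complex<T> y_conj{y.real, -y.imag};  // conjugate, as in the diff
    T denom = y_conj.real * y_conj.real + y_conj.imag * y_conj.imag;
    // dout / y_conj, expanded into real and imaginary parts.
    return {(dout.real * y_conj.real + dout.imag * y_conj.imag) / denom,
            (dout.imag * y_conj.real - dout.real * y_conj.imag) / denom};
  }
};

int main() {
  // The primary template handles real types:
  printf("%g\n", DivGradDX<double>()(2.0, 1.0));  // prints 0.5
  // The single partial specialization handles every complex<T>:
  complex<float> g = DivGradDX<complex<float>>()({0.0f, 1.0f}, {1.0f, 0.0f});
  printf("(%g, %g)\n", g.real, g.imag);  // 1 / conj(i) = 1 / (-i) = i -> (0, 1)
}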

paddle/fluid/operators/elementwise/elementwise_mul_op.cc

Lines changed: 7 additions & 8 deletions
@@ -16,8 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
 namespace operators {
@@ -134,19 +133,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad_grad,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -158,9 +157,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 
 REGISTER_OP_VERSION(elementwise_mul)
     .AddCheckpoint(
