Skip to content

Commit 3773748

Browse files
committed
encoder layer; modify the code according to review comments
2 parents f0836be + be066f6 commit 3773748

File tree

11 files changed

+236
-253
lines changed

11 files changed

+236
-253
lines changed

paddle/fluid/operators/fused/fused_dropout_act_bias.h

Lines changed: 38 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,27 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#ifndef _USE_MATH_DEFINES
17+
#define _USE_MATH_DEFINES
18+
#endif
1619

1720
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
18-
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
1921
#include "paddle/fluid/operators/math/functors.h"
2022

2123
namespace paddle {
2224
namespace operators {
2325

24-
typedef platform::float16 fp16;
25-
26-
/**
27-
*@brief the relu functor
28-
*/
29-
template <typename T>
30-
struct ReluFunctor {
31-
__host__ __device__ T operator()(const T *args) const {
32-
math::ReluFunctor<T> relu;
33-
return relu(args[0]);
34-
}
35-
};
36-
37-
template <typename T>
38-
struct ReluGradFunctor {
39-
__host__ __device__ __forceinline__ T operator()(const T *args) const {
40-
math::ReluGradFunctor<T> relu_grad;
41-
return args[0] * relu_grad.UseOut(args[1]);
42-
}
43-
};
44-
4526
/**
4627
*@brief the gelu functor
4728
*/
4829
template <typename T>
4930
struct GeluFunctor {
50-
__host__ __device__ T operator()(const T *args) const {
51-
math::GeluFunctor<T> gelu;
52-
return gelu(args[0]);
31+
inline __host__ __device__ T operator()(const T x) const {
32+
using U = LayerNormParamType<T>;
33+
const U casted_x = static_cast<U>(x);
34+
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
35+
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
36+
return static_cast<T>(out);
5337
}
5438
};
5539

@@ -58,11 +42,17 @@ struct GeluFunctor {
5842
*/
5943
template <typename T>
6044
struct GeluGradFunctor {
61-
__host__ __device__ T operator()(const T *args) const {
62-
const T grad = args[0];
63-
const T x = args[1];
64-
math::GeluGradFunctor<T> gelu_grad;
65-
return grad * gelu_grad.UseOut(x);
45+
inline __host__ __device__ T UseOut(const T x) const {
46+
using U = LayerNormParamType<T>;
47+
auto casted_x = static_cast<U>(x);
48+
49+
auto first =
50+
static_cast<U>(0.5) *
51+
(static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));
52+
53+
auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
54+
exp(-static_cast<U>(0.5) * casted_x * casted_x);
55+
return static_cast<T>((first + second));
6656
}
6757
};
6858

@@ -72,13 +62,12 @@ struct GeluGradFunctor {
7262
* the bias shape is (1, cols)
7363
*/
7464
template <typename T, typename MaskType, int VecSize, typename Functor>
75-
__global__ void FusedDropoutActBias(Functor act, const uint64_t seed,
76-
const uint64_t rows, const uint64_t cols,
77-
const int increment,
78-
const float dropout_prob,
79-
const bool is_upscale_in_train,
80-
const bool is_test, const T *src,
81-
const T *bias, T *dst, MaskType *mask) {
65+
__global__ void FusedDropoutActBias(
66+
Functor act, const uint64_t seed, const uint64_t rows, const uint64_t cols,
67+
const int increment, const float dropout_prob,
68+
const bool is_upscale_in_train, const bool is_test,
69+
const T *__restrict__ src, const T *__restrict__ bias, T *dst,
70+
MaskType *mask) {
8271
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
8372
int row_id = blockIdx.y;
8473
int idx = row_id * cols + col_id;
@@ -102,9 +91,8 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed,
10291
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
10392
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
10493

105-
const int tmp_cols = cols / VecSize * VecSize;
10694
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
107-
for (int i = col_id * VecSize; i < tmp_cols;
95+
for (int i = col_id * VecSize; i < cols;
10896
i += blockDim.x * gridDim.x * VecSize) {
10997
LoadT src_vec;
11098
LoadT bias_vec;
@@ -139,11 +127,14 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed,
139127
#pragma unroll
140128
for (int ii = 0; ii < VecSize; ii++) {
141129
const T tmp = src_vec[ii] + bias_vec[ii];
142-
dest_vec[ii] = act(&tmp) * static_cast<T>(mask_vec[ii]) * factor;
130+
const T act_out = act(tmp);
131+
dest_vec[ii] = act_out * static_cast<T>(mask_vec[ii]) * factor;
143132
}
144133
// store result to global
145134
platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
146-
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
135+
if (!is_test) {
136+
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
137+
}
147138
}
148139
}
149140
}
@@ -161,10 +152,8 @@ void LaunchDropoutActBias(Functor act_functor, const uint64_t seed,
161152
const platform::CUDADeviceContext &ctx) {
162153
// dropout_prob == 1.0f
163154
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
164-
PADDLE_ENFORCE_CUDA_SUCCESS(
165-
cudaMemsetAsync(dst, 0, rows * cols * sizeof(T), ctx.stream()));
166-
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
167-
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
155+
SetZero<T>(ctx, dst, rows * cols);
156+
SetZero<MaskType>(ctx, mask_data, rows * cols);
168157
return;
169158
}
170159

@@ -211,7 +200,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
211200
T args[2];
212201
args[0] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
213202
args[1] = src_vec[ii];
214-
dx_vec[ii] = act_grad(args);
203+
dx_vec[ii] = args[0] * act_grad.UseOut(args[1]);
215204
}
216205
platform::Store<T, VecSize>(dx_vec, &dx[i]);
217206
}
@@ -221,7 +210,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
221210
* blocks(128 * 8)
222211
* 1. calculate the dx and reduce total rows to 128 rows
223212
* 2. save 128*8 temporary sum in 8*128 shared memory
224-
* 3. reduce the sum of 128 rows data by 8*VecSize warps
213+
* 3. reduce the sum of 128 cols data by 8*VecSize warps
225214
*/
226215
template <typename T, typename MaskType, int BlockSizeX, int BlockSizeY,
227216
int VecSize, typename Functor>
@@ -257,43 +246,15 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
257246
T args[2];
258247
args[0] = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
259248
args[1] = src_vec[i] + bias_vec[i];
260-
val = act_grad(args);
249+
val = args[0] * act_grad.UseOut(args[1]);
261250
dx_vec[i] = val;
262251
tmp_sum[i] += val;
263252
}
264253
platform::Store<T, VecSize>(dx_vec, &dx[index]);
265254
}
266255
}
267256

268-
__shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
269-
for (int i = 0; i < VecSize; i++) {
270-
cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
271-
}
272-
__syncthreads();
273-
274-
// reduce sum
275-
T sum = static_cast<T>(0);
276-
int tid = threadIdx.y * blockDim.x + threadIdx.x;
277-
int x = tid >> 5; // warp id
278-
int y = tid & 31; // thread id on warp 0~31
279-
280-
// need BlockSizeX * VecSize warps
281-
if (x < BlockSizeX * VecSize) {
282-
// reduce 128 to 32
283-
#pragma unroll
284-
for (int i = 0; i < (BlockSizeY >> 5); i++) {
285-
sum += cache[x][y + i * 32];
286-
}
287-
}
288-
289-
// reduce 32 to 1
290-
sum = WarpReduceSum<T>(sum);
291-
292-
// save sum to dbias
293-
int bias_id = blockIdx.x * blockDim.x * VecSize + x;
294-
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
295-
dbias[bias_id] = sum;
296-
}
257+
CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
297258
}
298259

299260
/**

paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ limitations under the License. */
2020
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
2121
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
2222
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
23+
#include "paddle/fluid/operators/math/functors.h"
2324

2425
namespace framework = paddle::framework;
2526
namespace platform = paddle::platform;
2627
namespace details = paddle::operators::details;
27-
namespace operators = paddle::operators;
28+
namespace math = paddle::operators::math;
2829

2930
/**
3031
* @brief the unittest of fused_dropout_act_bias
@@ -111,16 +112,12 @@ struct TestFusedDropoutActBias {
111112
}
112113

113114
{
114-
out.Resize({rows, cols});
115-
out.mutable_data<T>(place);
116-
mask.Resize({rows, cols});
117-
mask.mutable_data<uint8_t>(place);
118-
dsrc.Resize({rows, cols});
119-
dsrc.mutable_data<T>(place);
115+
out.mutable_data<T>({rows, cols}, place);
116+
mask.mutable_data<uint8_t>({rows, cols}, place);
117+
dsrc.mutable_data<T>({rows, cols}, place);
120118

121119
if (has_bias) {
122-
dbias.Resize({cols});
123-
dbias.mutable_data<T>(place);
120+
dbias.mutable_data<T>({cols}, place);
124121
}
125122
}
126123
}
@@ -133,7 +130,7 @@ struct TestFusedDropoutActBias {
133130
for (int i = 0; i < rows; i++) {
134131
for (int j = 0; j < cols; j++) {
135132
const T tmp = src_vec[i * cols + j] + bias_vec[j];
136-
out1[i * cols + j] = act(&tmp);
133+
out1[i * cols + j] = act(tmp);
137134
}
138135
}
139136
// call dropout
@@ -143,7 +140,7 @@ struct TestFusedDropoutActBias {
143140
for (int i = 0; i < rows; i++) {
144141
for (int j = 0; j < cols; j++) {
145142
const T tmp = src_vec[i * cols + j];
146-
out1[i * cols + j] = act(&tmp);
143+
out1[i * cols + j] = act(tmp);
147144
}
148145
}
149146

@@ -164,22 +161,22 @@ struct TestFusedDropoutActBias {
164161
GradFunctor act_grad;
165162
for (int i = 0; i < rows; i++) {
166163
for (int j = 0; j < cols; j++) {
164+
T args[2];
165+
args[0] = _out[i * cols + j];
167166
if (has_bias) {
168-
T args[2];
169-
args[0] = _out[i * cols + j];
170167
args[1] = src_vec[i * cols + j] + bias_vec[j];
171-
T val = act_grad(args);
172-
correct_dbias[j] += val;
173-
correct_dsrc[i * cols + j] = val;
174168
} else {
175-
T args[2];
176-
args[0] = _out[i * cols + j];
177169
args[1] = src_vec[i * cols + j];
178-
T val = act_grad(args);
179-
correct_dsrc[i * cols + j] = val;
180170
}
171+
T val = args[0] * act_grad.UseOut(args[1]);
172+
correct_dsrc[i * cols + j] = val;
181173
}
182174
}
175+
176+
if (has_bias) {
177+
// reduce_sum: keep the same calculation order as the GPU
178+
ReduceSum<T>(correct_dsrc, &correct_dbias, rows, cols);
179+
}
183180
}
184181

185182
void FusedForward() {
@@ -273,47 +270,41 @@ static void BaseTest(const bool is_fp16 = false) {
273270
const int rows = 16;
274271
std::vector<int> cols_list = {16, 17};
275272
bool has_bias[2] = {true, false};
276-
T default_diff = !is_fp16 ? static_cast<T>(1e-3) : default_diff =
277-
static_cast<T>(1e-2);
273+
T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1);
278274
for (auto cols : {16, 17}) {
279275
for (auto has_bias : {true, false}) {
280276
TestFusedDropoutActBias<T, Functor, GradFunctor> test(rows, cols);
281277
test.has_bias = has_bias;
282278
test.Run();
283279
test.CheckOut(default_diff);
284-
if (!is_fp16) {
285-
test.CheckGrad(default_diff);
286-
}
280+
test.CheckGrad(default_diff);
287281
}
288282
}
289283
}
290284

291285
TEST(FusedDropout, GPUFusedDorpoutActBias) {
292-
BaseTest<float, paddle::operators::ReluFunctor<float>,
293-
paddle::operators::ReluGradFunctor<float>>();
294-
BaseTest<float, operators::GeluFunctor<float>,
295-
operators::GeluGradFunctor<float>>();
286+
BaseTest<float, math::ReluFunctor<float>, math::ReluGradFunctor<float>>();
287+
BaseTest<float, paddle::operators::GeluFunctor<float>,
288+
paddle::operators::GeluGradFunctor<float>>();
296289
}
297290
TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
298-
BaseTest<double, operators::ReluFunctor<double>,
299-
operators::ReluGradFunctor<double>>();
300-
BaseTest<double, operators::GeluFunctor<double>,
301-
operators::GeluGradFunctor<double>>();
291+
BaseTest<double, math::ReluFunctor<double>, math::ReluGradFunctor<double>>();
292+
BaseTest<double, paddle::operators::GeluFunctor<double>,
293+
paddle::operators::GeluGradFunctor<double>>();
302294
}
303295

304296
// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
305297
TEST(FusedDropout, GPUFusedDropoutActBiasFp16) {
306298
using fp16 = platform::float16;
307-
BaseTest<fp16, operators::ReluFunctor<fp16>,
308-
operators::ReluGradFunctor<fp16>>(true);
299+
BaseTest<fp16, math::ReluFunctor<fp16>, math::ReluGradFunctor<fp16>>(true);
309300
}
310301

311302
TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
312303
const int rows = 16;
313304
const int cols = 16;
314305
for (auto is_upscale_in_train : {true, false}) {
315-
TestFusedDropoutActBias<float, operators::ReluFunctor<float>,
316-
operators::ReluGradFunctor<float>>
306+
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
307+
math::ReluGradFunctor<float>>
317308
test(rows, cols, 0, 1.0, is_upscale_in_train, false);
318309
test.Run();
319310
test.CheckOut(static_cast<float>(1e-5));
@@ -324,8 +315,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
324315
TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
325316
const int rows = 16;
326317
const int cols = 16;
327-
TestFusedDropoutActBias<float, operators::ReluFunctor<float>,
328-
operators::ReluGradFunctor<float>>
318+
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
319+
math::ReluGradFunctor<float>>
329320
test(rows, cols, 0, 0.35, true, true);
330321
test.Run();
331322
test.CheckOut(static_cast<float>(1e-5));
@@ -335,8 +326,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
335326
TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
336327
const int rows = 16;
337328
const int cols = 16;
338-
TestFusedDropoutActBias<float, operators::ReluFunctor<float>,
339-
operators::ReluGradFunctor<float>>
329+
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
330+
math::ReluGradFunctor<float>>
340331
test(rows, cols, 125, 0.0, false, false);
341332
test.Run();
342333
test.CheckOut(static_cast<float>(1e-5));
@@ -346,8 +337,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
346337
TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) {
347338
const int rows = 256;
348339
const int cols = 4096;
349-
TestFusedDropoutActBias<float, operators::ReluFunctor<float>,
350-
operators::ReluGradFunctor<float>>
340+
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
341+
math::ReluGradFunctor<float>>
351342
test(rows, cols);
352343
test.Run();
353344
test.CheckOut(static_cast<float>(1e-5));

0 commit comments

Comments
 (0)