
Commit 036b430

optimize code structure to facilitate reuse
1 parent e2808ff commit 036b430

File tree

3 files changed: +159, -149 lines

paddle/fluid/operators/fused/fused_dropout.h

Lines changed: 0 additions & 12 deletions
@@ -66,17 +66,5 @@ struct alignas(sizeof(T) * VecSize) AlignedVector {
   T val[VecSize];
 };
 
-// reduce sum by a warp
-template <typename U>
-static __forceinline__ __device__ U WarpReduceSum(U val) {
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, true);
-  const int warpSize = 32;
-  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
-    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
-  }
-  return val;
-}
-
 }  // namespace operators
 }  // namespace paddle
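The removed WarpReduceSum helper summed a value across the 32 lanes of a warp with shuffle-down intrinsics; since fused_residual_dropout_bias.h now includes paddle/fluid/operators/layer_norm_kernel.cu.h (see the next file), the commit presumably reuses an equivalent reduction from that header instead of keeping a local copy. For reference, a minimal standalone sketch of the same technique, using the raw CUDA intrinsic __shfl_down_sync rather than Paddle's CREATE_SHFL_MASK/CudaShuffleDownSync wrappers (the name WarpReduceSumSketch is hypothetical):

// Minimal warp-level sum sketch (not Paddle's implementation): after the
// loop, lane 0 of a fully active warp holds the sum of all 32 lanes.
template <typename U>
__forceinline__ __device__ U WarpReduceSumSketch(U val) {
  const unsigned kFullMask = 0xffffffffu;  // assumes all 32 lanes participate
  for (int offset = 32 / 2; offset > 0; offset /= 2) {
    // add the value held by the lane `offset` positions above this one
    val += __shfl_down_sync(kFullMask, val, offset);
  }
  return val;
}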

paddle/fluid/operators/fused/fused_residual_dropout_bias.h

Lines changed: 143 additions & 115 deletions
@@ -15,14 +15,79 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/operators/fused/fused_dropout.h"
+#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 
 namespace paddle {
 namespace operators {
 
 namespace platform = paddle::platform;
 namespace cg = cooperative_groups;
 
+/**
+ * @brief fused the add_bias, dropout, add residual into one operators
+ *
+ */
+
 /********Forward**************/
+/**
+ * @brief the fused function called by every thread
+ */
+template <typename T, typename MaskType, typename U, int VecSize,
+          bool layer_norm>
+__forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread(
+    const int row_id, const int col_id, const int cols,
+    curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
+    const T *src, const T *residual, const T *bias, T *dst, MaskType *mask,
+    U *mean_val, U *var_val) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using MaskLoadT = AlignedVector<MaskType, VecSize>;
+  T src_vec[VecSize];
+  T residual_vec[VecSize];
+  T bias_vec[VecSize];
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    bias_vec[ii] = static_cast<T>(0);
+  }
+  // vectorize load data from global
+  LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
+  LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
+  *value = *reinterpret_cast<const LoadT *>(&src[row_id * cols + col_id]);
+  *residual_value =
+      *reinterpret_cast<const LoadT *>(&residual[row_id * cols + col_id]);
+
+  LoadT *bias_value =
+      bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
+  if (bias != nullptr)
+    *bias_value = *reinterpret_cast<const LoadT *>(&bias[col_id]);
+
+  float4 rand = curand_uniform4(state);
+  T dest_vec[VecSize];
+  MaskType mask_vec[VecSize];
+
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob);
+  }
+
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    dest_vec[ii] =
+        (src_vec[ii] + bias_vec[ii]) * static_cast<T>(mask_vec[ii]) * factor +
+        residual_vec[ii];
+    if (layer_norm) {
+      U tmp = static_cast<U>(dest_vec[ii]);
+      *mean_val += tmp;
+      *var_val += (tmp * tmp);
+    }
+  }
+
+  // store result to global
+  *(reinterpret_cast<LoadT *>(&dst[row_id * cols + col_id])) =
+      *reinterpret_cast<LoadT *>(&dest_vec[0]);
+  *(reinterpret_cast<MaskLoadT *>(&mask[row_id * cols + col_id])) =
+      *reinterpret_cast<MaskLoadT *>(&mask_vec[0]);
+}
+
 /**
  * @brief dst = residual + dropout(src + bias);
  * the src, residual, mask and dst shape is (rows, cols)
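The new per-thread helper above leans on the AlignedVector<T, VecSize> struct from fused_dropout.h: VecSize scalars are moved through one aligned aggregate so the compiler can emit a single wide (for example 128-bit) load or store instead of VecSize scalar transactions. Below is a minimal standalone sketch of that pattern; Vec and AlignedCopy are hypothetical stand-ins for the Paddle types and are not part of this commit.

#include <cuda_runtime.h>

// Hypothetical stand-in for Paddle's AlignedVector<T, VecSize>.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) Vec {
  T val[VecSize];
};

// Each thread copies VecSize contiguous elements with one vectorized load
// and one vectorized store; assumes src and dst are aligned to
// sizeof(T) * VecSize and n is a multiple of VecSize.
template <typename T, int VecSize>
__global__ void AlignedCopy(const T *src, T *dst, int n) {
  using LoadT = Vec<T, VecSize>;
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  if (i + VecSize <= n) {
    T buf[VecSize];
    // one wide load from global memory into registers
    *reinterpret_cast<LoadT *>(&buf[0]) =
        *reinterpret_cast<const LoadT *>(&src[i]);
    // (a real kernel would transform buf here, as the fused helper does)
    // one wide store back to global memory
    *reinterpret_cast<LoadT *>(&dst[i]) = *reinterpret_cast<LoadT *>(&buf[0]);
  }
}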
@@ -46,67 +111,71 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows,
   if (!is_upscale_in_train) {
     factor = static_cast<T>(1.0f);
   }
-  using LoadT = AlignedVector<T, VecSize>;
-  using MaskLoadT = AlignedVector<MaskType, VecSize>;
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
     for (int i = col_id * VecSize; i < cols;
          i += blockDim.x * gridDim.x * VecSize) {
-      T src_vec[VecSize];
-      T residual_vec[VecSize];
-      T bias_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        bias_vec[ii] = static_cast<T>(0);
-      }
-      // vectorize load data from global
-      LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
-      LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
-      *value = *reinterpret_cast<const LoadT *>(&src[r * cols + i]);
-      *residual_value =
-          *reinterpret_cast<const LoadT *>(&residual[r * cols + i]);
-
-      LoadT *bias_value =
-          bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
-      if (bias != nullptr)
-        *bias_value = *reinterpret_cast<const LoadT *>(&bias[i]);
-
-      float4 rand = curand_uniform4(&state);
-      T dest_vec[VecSize];
-      MaskType mask_vec[VecSize];
+      FusedResidualDropoutBiasVecOneThread<T, MaskType, T, VecSize, false>(
+          r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
+          mask, NULL, NULL);
+    }
+  }
+}
 
+/**
+ * @brief the fused function called by every thread
+ */
+template <typename T, typename U, int VecSize, bool layer_norm>
+__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread(
+    const int row_id, const int col_id, const int cols,
+    const float dropout_prob, const T factor, const T *src, const T *residual,
+    const T *bias, T *dst, U *mean_val, U *var_val) {
+  using LoadT = AlignedVector<T, VecSize>;
+  T src_vec[VecSize];
+  T residual_vec[VecSize];
+  T bias_vec[VecSize];
 #pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob);
-      }
+  for (int ii = 0; ii < VecSize; ii++) {
+    bias_vec[ii] = static_cast<T>(0);
+  }
+  // vectorize load data from global
+  LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
+  LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
+  *value = *reinterpret_cast<const LoadT *>(&src[row_id * cols + col_id]);
+  *residual_value =
+      *reinterpret_cast<const LoadT *>(&residual[row_id * cols + col_id]);
 
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) *
-                           static_cast<T>(mask_vec[ii]) * factor +
-                       residual_vec[ii];
-      }
+  LoadT *bias_value =
+      bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
+  if (bias != nullptr)
+    *bias_value = *reinterpret_cast<const LoadT *>(&bias[col_id]);
+
+  T dest_vec[VecSize];
 
-      // store result to global
-      *(reinterpret_cast<LoadT *>(&dst[r * cols + i])) =
-          *reinterpret_cast<LoadT *>(&dest_vec[0]);
-      *(reinterpret_cast<MaskLoadT *>(&mask[r * cols + i])) =
-          *reinterpret_cast<MaskLoadT *>(&mask_vec[0]);
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii];
+    if (layer_norm) {
+      U tmp = static_cast<U>(dest_vec[ii]);
+      *mean_val += tmp;
+      *var_val += (tmp * tmp);
     }
   }
+
+  // store result to global
+  *(reinterpret_cast<LoadT *>(&dst[row_id * cols + col_id])) =
+      *reinterpret_cast<LoadT *>(&dest_vec[0]);
 }
 
 /**
- * @brief for dropout's param is_test = true
+ * @brief for dropout's param is_test = true, only used in inference
  * the src, residual and dst shape is (rows, cols)
  * the bias shape is (1, cols)
  */
 template <typename T, int VecSize>
-__global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
-                                               const size_t cols,
-                                               const float dropout_prob,
-                                               const bool is_upscale_in_train,
-                                               const T *src, const T *residual,
-                                               const T *bias, T *dst) {
+__global__ void FusedResidualDropoutBiasOnlyInferVec(
+    const size_t rows, const size_t cols, const float dropout_prob,
+    const bool is_upscale_in_train, const T *src, const T *residual,
+    const T *bias, T *dst) {
   int col_id = blockDim.x * blockIdx.x + threadIdx.x;
   int row_id = blockIdx.y;
   int idx = row_id * cols + col_id;
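For clarity, the element-wise math the training kernel now delegates to FusedResidualDropoutBiasVecOneThread is dst = residual + dropout(src + bias), exactly as the doc comment above states. A scalar reference sketch in plain C++ follows; the function name and the precomputed keep vector are hypothetical (the real mask comes from curand_uniform4 inside the kernel), and factor = 1 / (1 - dropout_prob) in upscale-in-train mode is an assumption based on standard dropout scaling, since the initial factor value is not visible in this hunk.

#include <cstdint>
#include <vector>

// Scalar reference sketch of dst = residual + dropout(src + bias).
// keep[r * cols + c] plays the role of the mask the CUDA kernel draws
// from curand; bias has shape (1, cols), everything else (rows, cols).
void ResidualDropoutBiasRef(const std::vector<float> &src,
                            const std::vector<float> &residual,
                            const std::vector<float> &bias,
                            const std::vector<uint8_t> &keep,
                            float dropout_prob, bool is_upscale_in_train,
                            int rows, int cols, std::vector<float> *dst) {
  // assumed scaling: upscale kept values during training, otherwise 1
  const float factor =
      is_upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int i = r * cols + c;
      (*dst)[i] =
          (src[i] + bias[c]) * static_cast<float>(keep[i]) * factor +
          residual[i];
    }
  }
}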
@@ -116,39 +185,12 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
     factor = static_cast<T>(1.0f);
   }
 
-  using LoadT = AlignedVector<T, VecSize>;
-
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
     for (int i = col_id * VecSize; i < cols;
          i += blockDim.x * gridDim.x * VecSize) {
-      T src_vec[VecSize];
-      T residual_vec[VecSize];
-      T bias_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        bias_vec[ii] = static_cast<T>(0);
-      }
-      // vectorize load data from global
-      LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
-      LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
-      *value = *reinterpret_cast<const LoadT *>(&src[r * cols + i]);
-      *residual_value =
-          *reinterpret_cast<const LoadT *>(&residual[r * cols + i]);
-
-      LoadT *bias_value =
-          bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
-      if (bias != nullptr)
-        *bias_value = *reinterpret_cast<const LoadT *>(&bias[i]);
-
-      T dest_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii];
-      }
-
-      // store result to global
-      *(reinterpret_cast<LoadT *>(&dst[r * cols + i])) =
-          *reinterpret_cast<LoadT *>(&dest_vec[0]);
+      FusedResidualDropoutBiasOnlyInferVecOneThread<T, T, VecSize, false>(
+          r, i, cols, dropout_prob, factor, src, residual, bias, dst, nullptr,
+          nullptr);
     }
   }
 }
@@ -159,7 +201,7 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
 template <typename T, typename MaskType>
 void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
                                const int increment, uint64_t seed,
-                               const float dropout_prob,
+                               const float dropout_prob, const bool is_test,
                                bool is_upscale_in_train, const T *src,
                                const T *residual, const T *bias,
                                MaskType *mask_data, T *dst,
@@ -176,46 +218,32 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
 
   const int VecSize = 4;
   auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols);
-  if (cols % VecSize != 0)
-    FusedResidualDropoutBiasVec<
-        T, uint8_t, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
-        bias, mask_data, dst, increment);
-  else
-    FusedResidualDropoutBiasVec<
-        T, uint8_t,
-        VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
-        bias, mask_data, dst, increment);
-}
-
-/**
- *@brief to launch kernel FusedResidualDropoutBiasIsTest
- */
-template <typename T>
-void LaunchResidualDropoutBiasIsTest(const uint32_t rows, const uint32_t cols,
-                                     const float dropout_prob,
-                                     bool is_upscale_in_train, const T *src,
-                                     const T *residual, const T *bias, T *dst,
-                                     const platform::CUDADeviceContext &ctx) {
-  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T),
-                        cudaMemcpyDeviceToDevice, ctx.stream()));
-    return;
+  if (cols % VecSize != 0) {
+    if (!is_test) {
+      FusedResidualDropoutBiasVec<
+          T, uint8_t, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+          bias, mask_data, dst, increment);
+    } else {
+      FusedResidualDropoutBiasOnlyInferVec<
+          T, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
+          dst);
+    }
+  } else {
+    if (!is_test) {
+      FusedResidualDropoutBiasVec<
+          T, uint8_t,
+          VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+          bias, mask_data, dst, increment);
+    } else {
+      FusedResidualDropoutBiasOnlyInferVec<
+          T, VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
+          dst);
+    }
   }
-  const int VecSize = 4;
-  auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols);
-  if (cols % VecSize != 0)
-    FusedResidualDropoutBiasIsTest<
-        T, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
-        dst);
-  else
-    FusedResidualDropoutBiasIsTest<
-        T, VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
-        dst);
 }
 
 /********Backward**************/
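With the is_test flag folded into LaunchResidualDropoutBias, the separate LaunchResidualDropoutBiasIsTest launcher disappears and a single function now selects both the vector width (from cols % VecSize) and the training or inference kernel. A simplified, hypothetical skeleton of that dispatch shape is sketched below; the dummy TrainKernel/InferKernel and LaunchSketch names stand in for the real fused kernels and launcher and are not part of this commit.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical placeholder kernels; the real ones are the fused
// FusedResidualDropoutBiasVec / FusedResidualDropoutBiasOnlyInferVec.
template <typename T, int VecSize>
__global__ void TrainKernel(const T *src, T *dst, uint8_t *mask, size_t n) {}
template <typename T, int VecSize>
__global__ void InferKernel(const T *src, T *dst, size_t n) {}

// One launcher, two nested branches: vector width from the column count,
// kernel choice from is_test, mirroring the new LaunchResidualDropoutBias.
template <typename T>
void LaunchSketch(const T *src, T *dst, uint8_t *mask, size_t rows,
                  size_t cols, bool is_test, cudaStream_t stream) {
  constexpr int kVecSize = 4;
  const size_t n = rows * cols;
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  if (cols % kVecSize != 0) {
    if (!is_test) {
      TrainKernel<T, 1><<<blocks, threads, 0, stream>>>(src, dst, mask, n);
    } else {
      InferKernel<T, 1><<<blocks, threads, 0, stream>>>(src, dst, n);
    }
  } else {
    if (!is_test) {
      TrainKernel<T, kVecSize><<<blocks, threads, 0, stream>>>(src, dst, mask,
                                                               n);
    } else {
      InferKernel<T, kVecSize><<<blocks, threads, 0, stream>>>(src, dst, n);
    }
  }
}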
