@@ -15,14 +15,79 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/operators/fused/fused_dropout.h"
+#include "paddle/fluid/operators/layer_norm_kernel.cu.h"

 namespace paddle {
 namespace operators {

 namespace platform = paddle::platform;
 namespace cg = cooperative_groups;

+/**
+ * @brief Fuse the add_bias, dropout and add_residual into one operator.
+ */
+
 /********Forward**************/
+/**
+ * @brief The per-thread fused computation: load a VecSize-wide vector, add
+ * the bias, apply dropout, add the residual, and optionally accumulate the
+ * mean/variance needed by a following layer_norm.
+ */
+template <typename T, typename MaskType, typename U, int VecSize,
+          bool layer_norm>
+__forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread(
+    const int row_id, const int col_id, const int cols,
+    curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
+    const T *src, const T *residual, const T *bias, T *dst, MaskType *mask,
+    U *mean_val, U *var_val) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using MaskLoadT = AlignedVector<MaskType, VecSize>;
+  T src_vec[VecSize];
+  T residual_vec[VecSize];
+  T bias_vec[VecSize];
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    bias_vec[ii] = static_cast<T>(0);
+  }
+  // vectorize load data from global
+  LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
+  LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
+  *value = *reinterpret_cast<const LoadT *>(&src[row_id * cols + col_id]);
+  *residual_value =
+      *reinterpret_cast<const LoadT *>(&residual[row_id * cols + col_id]);
+
+  LoadT *bias_value =
+      bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
+  if (bias != nullptr)
+    *bias_value = *reinterpret_cast<const LoadT *>(&bias[col_id]);
+
+  float4 rand = curand_uniform4(state);
+  T dest_vec[VecSize];
+  MaskType mask_vec[VecSize];
+
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob);
+  }
+
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    dest_vec[ii] =
+        (src_vec[ii] + bias_vec[ii]) * static_cast<T>(mask_vec[ii]) * factor +
+        residual_vec[ii];
+    if (layer_norm) {
+      U tmp = static_cast<U>(dest_vec[ii]);
+      *mean_val += tmp;
+      *var_val += (tmp * tmp);
+    }
+  }
+
+  // store result to global
+  *(reinterpret_cast<LoadT *>(&dst[row_id * cols + col_id])) =
+      *reinterpret_cast<LoadT *>(&dest_vec[0]);
+  *(reinterpret_cast<MaskLoadT *>(&mask[row_id * cols + col_id])) =
+      *reinterpret_cast<MaskLoadT *>(&mask_vec[0]);
+}
+
 /**
  * @brief dst = residual + dropout(src + bias);
  * the src, residual, mask and dst shape is (rows, cols)
@@ -46,67 +111,71 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows,
   if (!is_upscale_in_train) {
     factor = static_cast<T>(1.0f);
   }
-  using LoadT = AlignedVector<T, VecSize>;
-  using MaskLoadT = AlignedVector<MaskType, VecSize>;
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
     for (int i = col_id * VecSize; i < cols;
          i += blockDim.x * gridDim.x * VecSize) {
-      T src_vec[VecSize];
-      T residual_vec[VecSize];
-      T bias_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        bias_vec[ii] = static_cast<T>(0);
-      }
-      // vectorize load data from global
-      LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
-      LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
-      *value = *reinterpret_cast<const LoadT *>(&src[r * cols + i]);
-      *residual_value =
-          *reinterpret_cast<const LoadT *>(&residual[r * cols + i]);
-
-      LoadT *bias_value =
-          bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
-      if (bias != nullptr)
-        *bias_value = *reinterpret_cast<const LoadT *>(&bias[i]);
-
-      float4 rand = curand_uniform4(&state);
-      T dest_vec[VecSize];
-      MaskType mask_vec[VecSize];
+      FusedResidualDropoutBiasVecOneThread<T, MaskType, T, VecSize, false>(
+          r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
+          mask, NULL, NULL);
+    }
+  }
+}

+/**
+ * @brief The per-thread fused computation for the inference path: same as
+ * above, but without dropout mask generation.
+ */
+template <typename T, typename U, int VecSize, bool layer_norm>
+__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread(
+    const int row_id, const int col_id, const int cols,
+    const float dropout_prob, const T factor, const T *src, const T *residual,
+    const T *bias, T *dst, U *mean_val, U *var_val) {
+  using LoadT = AlignedVector<T, VecSize>;
+  T src_vec[VecSize];
+  T residual_vec[VecSize];
+  T bias_vec[VecSize];
 #pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob);
-      }
+  for (int ii = 0; ii < VecSize; ii++) {
+    bias_vec[ii] = static_cast<T>(0);
+  }
+  // vectorize load data from global
+  LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
+  LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
+  *value = *reinterpret_cast<const LoadT *>(&src[row_id * cols + col_id]);
+  *residual_value =
+      *reinterpret_cast<const LoadT *>(&residual[row_id * cols + col_id]);

-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) *
-                           static_cast<T>(mask_vec[ii]) * factor +
-                       residual_vec[ii];
-      }
+  LoadT *bias_value =
+      bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
+  if (bias != nullptr)
+    *bias_value = *reinterpret_cast<const LoadT *>(&bias[col_id]);
+
+  T dest_vec[VecSize];

-      // store result to global
-      *(reinterpret_cast<LoadT *>(&dst[r * cols + i])) =
-          *reinterpret_cast<LoadT *>(&dest_vec[0]);
-      *(reinterpret_cast<MaskLoadT *>(&mask[r * cols + i])) =
-          *reinterpret_cast<MaskLoadT *>(&mask_vec[0]);
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii];
+    if (layer_norm) {
+      U tmp = static_cast<U>(dest_vec[ii]);
+      *mean_val += tmp;
+      *var_val += (tmp * tmp);
     }
   }
+
+  // store result to global
+  *(reinterpret_cast<LoadT *>(&dst[row_id * cols + col_id])) =
+      *reinterpret_cast<LoadT *>(&dest_vec[0]);
 }

 /**
- * @brief for dropout's param is_test = true
+ * @brief for dropout's param is_test = true, i.e. only used in inference
  * the src, residual and dst shape is (rows, cols)
  * the bias shape is (1, cols)
  */
 template <typename T, int VecSize>
-__global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
-                                               const size_t cols,
-                                               const float dropout_prob,
-                                               const bool is_upscale_in_train,
-                                               const T *src, const T *residual,
-                                               const T *bias, T *dst) {
+__global__ void FusedResidualDropoutBiasOnlyInferVec(
+    const size_t rows, const size_t cols, const float dropout_prob,
+    const bool is_upscale_in_train, const T *src, const T *residual,
+    const T *bias, T *dst) {
   int col_id = blockDim.x * blockIdx.x + threadIdx.x;
   int row_id = blockIdx.y;
   int idx = row_id * cols + col_id;
@@ -116,39 +185,12 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
     factor = static_cast<T>(1.0f);
   }

-  using LoadT = AlignedVector<T, VecSize>;
-
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
     for (int i = col_id * VecSize; i < cols;
          i += blockDim.x * gridDim.x * VecSize) {
-      T src_vec[VecSize];
-      T residual_vec[VecSize];
-      T bias_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        bias_vec[ii] = static_cast<T>(0);
-      }
-      // vectorize load data from global
-      LoadT *value = reinterpret_cast<LoadT *>(&src_vec);
-      LoadT *residual_value = reinterpret_cast<LoadT *>(&residual_vec);
-      *value = *reinterpret_cast<const LoadT *>(&src[r * cols + i]);
-      *residual_value =
-          *reinterpret_cast<const LoadT *>(&residual[r * cols + i]);
-
-      LoadT *bias_value =
-          bias != nullptr ? reinterpret_cast<LoadT *>(&bias_vec) : nullptr;
-      if (bias != nullptr)
-        *bias_value = *reinterpret_cast<const LoadT *>(&bias[i]);
-
-      T dest_vec[VecSize];
-#pragma unroll
-      for (int ii = 0; ii < VecSize; ii++) {
-        dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii];
-      }
-
-      // store result to global
-      *(reinterpret_cast<LoadT *>(&dst[r * cols + i])) =
-          *reinterpret_cast<LoadT *>(&dest_vec[0]);
+      FusedResidualDropoutBiasOnlyInferVecOneThread<T, T, VecSize, false>(
+          r, i, cols, dropout_prob, factor, src, residual, bias, dst, nullptr,
+          nullptr);
     }
   }
 }
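For reference, both the training and inference kernels above compute dst = residual + dropout(src + bias) element-wise; kept elements are scaled by factor, which the training kernel resets to 1.0 when is_upscale_in_train is false and which is otherwise expected to be 1 / (1 - dropout_prob), the usual upscale-in-train convention (the exact value is computed outside this hunk). A minimal host-side sketch of that math, written as a hypothetical reference helper rather than code from this commit:

#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the fused forward pass: row-major
// (rows, cols) tensors, a (1, cols) bias, and keep[i] standing in for the
// dropout mask that the GPU kernel draws from curand.
template <typename T>
void FusedResidualDropoutBiasRef(int rows, int cols, float dropout_prob,
                                 bool is_upscale_in_train,
                                 const std::vector<T> &src,
                                 const std::vector<T> &residual,
                                 const std::vector<T> &bias,
                                 const std::vector<uint8_t> &keep,
                                 std::vector<T> *dst) {
  // Assumed scaling: 1/(1-p) when upscaling in train, otherwise 1.0.
  const T factor = is_upscale_in_train
                       ? static_cast<T>(1.0f / (1.0f - dropout_prob))
                       : static_cast<T>(1.0f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int i = r * cols + c;
      const T b = bias.empty() ? static_cast<T>(0) : bias[c];
      (*dst)[i] =
          (src[i] + b) * static_cast<T>(keep[i]) * factor + residual[i];
    }
  }
}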
@@ -159,7 +201,7 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
 template <typename T, typename MaskType>
 void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
                                const int increment, uint64_t seed,
-                               const float dropout_prob,
+                               const float dropout_prob, const bool is_test,
                                bool is_upscale_in_train, const T *src,
                                const T *residual, const T *bias,
                                MaskType *mask_data, T *dst,
@@ -176,46 +218,32 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,

   const int VecSize = 4;
   auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols);
-  if (cols % VecSize != 0)
-    FusedResidualDropoutBiasVec<
-        T, uint8_t, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
-        bias, mask_data, dst, increment);
-  else
-    FusedResidualDropoutBiasVec<
-        T, uint8_t,
-        VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
-        bias, mask_data, dst, increment);
-}
-
-/**
- *@brief to launch kernel FusedResidualDropoutBiasIsTest
- */
-template <typename T>
-void LaunchResidualDropoutBiasIsTest(const uint32_t rows, const uint32_t cols,
-                                     const float dropout_prob,
-                                     bool is_upscale_in_train, const T *src,
-                                     const T *residual, const T *bias, T *dst,
-                                     const platform::CUDADeviceContext &ctx) {
-  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T),
-                        cudaMemcpyDeviceToDevice, ctx.stream()));
-    return;
+  if (cols % VecSize != 0) {
+    if (!is_test) {
+      FusedResidualDropoutBiasVec<
+          T, uint8_t, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+          bias, mask_data, dst, increment);
+    } else {
+      FusedResidualDropoutBiasOnlyInferVec<
+          T, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
+          dst);
+    }
+  } else {
+    if (!is_test) {
+      FusedResidualDropoutBiasVec<
+          T, uint8_t,
+          VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+          bias, mask_data, dst, increment);
+    } else {
+      FusedResidualDropoutBiasOnlyInferVec<
+          T, VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
+          rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
+          dst);
+    }
   }
-  const int VecSize = 4;
-  auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols);
-  if (cols % VecSize != 0)
-    FusedResidualDropoutBiasIsTest<
-        T, 1><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
-        dst);
-  else
-    FusedResidualDropoutBiasIsTest<
-        T, VecSize><<<threads.second, threads.first, 0, ctx.stream()>>>(
-        rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias,
-        dst);
 }

 /********Backward**************/
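The OneThread helpers above all rely on the same vectorization idiom: VecSize contiguous elements are moved through an aligned vector type so that one reinterpret_cast load or store becomes a single wide memory transaction, and VecSize = 4 lines up with the four uniforms returned by one curand_uniform4 call. Below is a self-contained sketch of that idiom; AlignedVecSketch is a stand-in for the real AlignedVector, which is defined in fused_dropout.h and not shown in this diff.

#include <cstdint>

#include <curand_kernel.h>

// Stand-in for AlignedVector<T, N>: an over-aligned POD, so a single
// reinterpret_cast load/store moves N elements in one wide transaction.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVecSketch {
  T val[N];
};

// Minimal illustration of the load -> mask -> scale -> store pattern used by
// the kernels above. VecSize = 4 matches the four uniforms produced by one
// curand_uniform4 call.
template <typename T, int VecSize = 4>
__device__ void VectorizedDropoutSketch(const T *src, T *dst, uint8_t *mask,
                                        int idx, float dropout_prob, T factor,
                                        curandStatePhilox4_32_10_t *state) {
  using LoadT = AlignedVecSketch<T, VecSize>;
  using MaskT = AlignedVecSketch<uint8_t, VecSize>;
  T buf[VecSize];
  uint8_t m[VecSize];
  // One VecSize-wide load instead of VecSize scalar loads.
  *reinterpret_cast<LoadT *>(&buf[0]) =
      *reinterpret_cast<const LoadT *>(&src[idx]);
  float4 rand = curand_uniform4(state);  // four uniforms in (0, 1]
#pragma unroll
  for (int i = 0; i < VecSize; i++) {
    m[i] = static_cast<uint8_t>((&rand.x)[i] >= dropout_prob);
    buf[i] = buf[i] * static_cast<T>(m[i]) * factor;
  }
  // One wide store each for the result and the mask.
  *reinterpret_cast<LoadT *>(&dst[idx]) = *reinterpret_cast<LoadT *>(&buf[0]);
  *reinterpret_cast<MaskT *>(&mask[idx]) = *reinterpret_cast<MaskT *>(&m[0]);
}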