@@ -14,14 +14,15 @@ limitations under the License. */
1414
1515#pragma once
1616
17- #include " paddle/fluid/operators/fused/fused_dropout .h"
17+ #include " paddle/fluid/operators/fused/fused_dropout_common .h"
1818#include " paddle/fluid/operators/layer_norm_kernel.cu.h"
1919
2020namespace paddle {
2121namespace operators {
2222
2323namespace platform = paddle::platform;
2424namespace cg = cooperative_groups;
25+ namespace memory = paddle::memory;
2526
2627/* *
2728 * @brief fuses add_bias, dropout, and add-residual into one operator
@@ -32,15 +33,17 @@ namespace cg = cooperative_groups;
3233/* *
3334 * @brief the fused function called by every thread
3435 */
35- template <typename T, typename MaskType, typename U, int VecSize,
36- bool layer_norm>
36+ template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm>
3737__forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread (
3838 const int row_id, const int col_id, const int cols,
3939 curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
4040 const T *src, const T *residual, const T *bias, T *dst, MaskType *mask,
41- U *mean_val, U *var_val) {
41+ typename details::MPTypeTrait<T>::Type *mean_val,
42+ typename details::MPTypeTrait<T>::Type *var_val) {
4243 using LoadT = AlignedVector<T, VecSize>;
4344 using MaskLoadT = AlignedVector<MaskType, VecSize>;
45+ using U = typename details::MPTypeTrait<T>::Type;
46+
4447 T src_vec[VecSize];
4548 T residual_vec[VecSize];
4649 T bias_vec[VecSize];
@@ -74,7 +77,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread(
7477 dest_vec[ii] =
7578 (src_vec[ii] + bias_vec[ii]) * static_cast <T>(mask_vec[ii]) * factor +
7679 residual_vec[ii];
77- if (layer_norm ) {
80+ if (ComputeLayerNorm ) {
7881 U tmp = static_cast <U>(dest_vec[ii]);
7982 *mean_val += tmp;
8083 *var_val += (tmp * tmp);
@@ -114,7 +117,7 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows,
114117 for (int r = row_id; r < rows; r += blockDim.y * gridDim.y ) {
115118 for (int i = col_id * VecSize; i < cols;
116119 i += blockDim.x * gridDim.x * VecSize) {
117- FusedResidualDropoutBiasVecOneThread<T, MaskType, T, VecSize, false >(
120+ FusedResidualDropoutBiasVecOneThread<T, MaskType, VecSize, false >(
118121 r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
119122 mask, NULL , NULL );
120123 }
@@ -208,9 +211,10 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
208211 const platform::CUDADeviceContext &ctx) {
209212 // dropout_prob == 1.0f
210213 if (std::abs (dropout_prob - 1 .0f ) < 1e-5 ) {
211- PADDLE_ENFORCE_CUDA_SUCCESS (
212- cudaMemcpyAsync (dst, residual, rows * cols * sizeof (T),
213- cudaMemcpyDeviceToDevice, ctx.stream ()));
214+ if (residual == dst) return ;
215+ auto cuda_place = BOOST_GET_CONST (platform::CUDAPlace, ctx.GetPlace ());
216+ memory::Copy (cuda_place, dst, cuda_place, residual, rows * cols * sizeof (T),
217+ ctx.stream ());
214218 PADDLE_ENFORCE_CUDA_SUCCESS (cudaMemsetAsync (
215219 mask_data, 0 , rows * cols * sizeof (MaskType), ctx.stream ()));
216220 return ;
@@ -282,7 +286,8 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask,
282286 * 2. save 128*8 temporary sum in 8*128 shared memory
283287 * 3. reduce the sum of 128 rows data by 8*VecSize warps
284288 */
285- template <typename T, typename MaskType, int BSX, int BSY, int VecSize>
289+ template <typename T, typename MaskType, int BLOCK_SIZE_X, int BLOCK_SIZE_Y,
290+ int VecSize>
286291__global__ void FusedResidualDropoutBiasGradVec (
287292 const T *dout, const MaskType *mask, const T factor, const int64_t rows,
288293 const int64_t cols, T *dx, T *dbias) {
@@ -316,9 +321,10 @@ __global__ void FusedResidualDropoutBiasGradVec(
316321 }
317322
318323 // save temporary sum to cache and do transpose
319- __shared__ T cache[BSX * VecSize][BSY ];
320- for (int i = 0 ; i < VecSize; i++)
324+ __shared__ T cache[BLOCK_SIZE_X * VecSize][BLOCK_SIZE_Y ];
325+ for (int i = 0 ; i < VecSize; i++) {
321326 cache[threadIdx.x * VecSize + i][threadIdx.y ] = tmp_sum[i];
327+ }
322328 __syncthreads ();
323329
324330 // reduce sum
@@ -327,11 +333,11 @@ __global__ void FusedResidualDropoutBiasGradVec(
327333 int x = tid >> 5 ; // warp id
328334 int y = tid & 31 ; // thread id on warp 0~31
329335
330- // need BSX * VecSize warps
331- if (x < BSX * VecSize) {
336+ // need BLOCK_SIZE_X * VecSize warps
337+ if (x < BLOCK_SIZE_X * VecSize) {
332338// reduce 128 to 32
333339#pragma unroll
334- for (int i = 0 ; i < (BSY >> 5 ); i++) {
340+ for (int i = 0 ; i < (BLOCK_SIZE_Y >> 5 ); i++) {
335341 sum += cache[x][y + i * 32 ];
336342 }
337343 }
@@ -341,7 +347,7 @@ __global__ void FusedResidualDropoutBiasGradVec(
341347
342348 // save sum to dbias
343349 int bias_id = blockIdx.x * blockDim.x * VecSize + x;
344- if (y == 0 && x < VecSize * BSX && bias_id < cols) {
350+ if (y == 0 && x < VecSize * BLOCK_SIZE_X && bias_id < cols) {
345351 dbias[bias_id] = sum;
346352 }
347353}
@@ -367,7 +373,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask,
367373 const int VecSize = 4 ;
368374 if (dbias != nullptr ) {
369375 int real_vec_size = VecSize;
370- if (cols % VecSize != 0 ) real_vec_size = 1 ;
376+ if (cols % VecSize != 0 ) {
377+ real_vec_size = 1 ;
378+ }
371379 auto threads = std::min (cols / real_vec_size, static_cast <uint32_t >(8 ));
372380 auto blocks = std::max (
373381 (uint32_t )1 , std::min ((cols / real_vec_size + threads - 1 ) / threads,
0 commit comments