@@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
6464}
6565
6666template <typename U>
67- __forceinline__ __device__ U BlockReduceSum (U val) {
68- static __shared__ U shared[32 ];
67+ __forceinline__ __device__ U BlockReduceSum (U val, U *shared) {
6968 int lane = threadIdx .x % warpSize ;
7069 int wid = threadIdx .x / warpSize ;
7170
7271 val = WarpReduceSum (val); // Each warp performs partial reduction
7372
73+ __syncthreads ();
7474 if (lane == 0 ) shared[wid] = val; // Write reduced value to shared memory
7575
7676 __syncthreads (); // Wait for all partial reductions
77-
7877 // read from shared memory only if that warp existed
7978 val =
8079 (threadIdx .x < blockDim .x / warpSize ) ? shared[lane] : static_cast <U>(0 );
@@ -183,6 +182,9 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
183182 int64_t feature_size) {
184183 __shared__ U mean_share;
185184 __shared__ U var_share;
185+ __shared__ U shared_mean[32 ]; // threadIdx.x / warpSize <= kMaxBlockDim /
186+ // warpSize <= 1024/32 = 32;
187+ __shared__ U shared_var[32 ];
186188
187189 int64_t beg_idx = blockIdx .x * feature_size + threadIdx .x ;
188190 int64_t end_idx = (blockIdx .x + 1 ) * feature_size;
@@ -196,8 +198,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
196198 var_val += (tmp * tmp);
197199 }
198200
199- mean_val = BlockReduceSum<U>(mean_val);
200- var_val = BlockReduceSum<U>(var_val);
201+ mean_val = BlockReduceSum<U>(mean_val, shared_mean );
202+ var_val = BlockReduceSum<U>(var_val, shared_var );
201203
202204 if (threadIdx .x == 0 ) {
203205 auto scale = static_cast <float >(1 .) / static_cast <float >(feature_size);
@@ -541,8 +543,11 @@ __global__ void LayerNormBackwardGradientAll(
541543 }
542544 }
543545
544- d_scale_partial = BlockReduceSum<U>(d_scale_partial);
545- d_bias_partial = BlockReduceSum<U>(d_bias_partial);
546+ __shared__ U shared_scale[32 ]; // threadIdx.x / warpSize <= kMaxBlockDim /
547+ // warpSize <= 1024/32 = 32;
548+ __shared__ U shared_bias[32 ];
549+ d_scale_partial = BlockReduceSum<U>(d_scale_partial, shared_scale);
550+ d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
546551
547552 if (threadIdx .x == 0 ) {
548553 d_scale[blockIdx .x + col_offset] = d_scale_partial;
0 commit comments