Commit 3e4b32e

zkh2016 authored and AnnaTrainingG committed
fix the bug of layer_norm when batch_size=1 (PaddlePaddle#35480)
The bug was an incorrect access to mean and var that read out of array bounds: mean and var have shape [batch_size], while the thread idx ranges over 0~feature_size, so mean[idx] and var[idx] overrun the buffers whenever feature_size exceeds batch_size. When batch_size=1, the correct accesses are mean[0] and var[0]. A unit test with batch_size=1 is added.
1 parent b91e4a7 commit 3e4b32e
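
To make the out-of-bounds access concrete, here is a minimal NumPy sketch of the indexing the kernel performs; the sizes and variable names are illustrative, not taken from the commit:

import numpy as np

batch_size, feature_size = 1, 60           # e.g. shape [1, 3, 4, 5] normalized from axis 1
mean = np.zeros([batch_size], np.float32)  # statistics are per-sample: shape [batch_size]
var = np.ones([batch_size], np.float32)
x = np.random.rand(feature_size).astype(np.float32)

for idx in range(feature_size):            # CUDA thread idx spans 0..feature_size-1
    # Buggy pattern: mean[idx] / var[idx] -- NumPy would raise IndexError for
    # idx >= 1, while the CUDA kernel silently read past the end of the buffer.
    # Fixed pattern: batch_size == 1 means a single statistic, so index [0].
    normalized = (x[idx] - mean[0]) / np.sqrt(var[0] + 1e-5)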

File tree

2 files changed (+3, -2 lines)

paddle/fluid/operators/layer_norm_kernel.cu.h

Lines changed: 2 additions & 2 deletions

@@ -705,7 +705,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
+        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
@@ -717,7 +717,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(

     if (d_scale != nullptr) {
       d_scale[idx] = static_cast<U>(d_y[idx]) *
-                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
+                     (static_cast<U>(x[idx]) - mean[0]) / var_val;
     }

     if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
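
For reference, a NumPy transcription of what the patched kernel computes per element when batch_size == 1. Only the lines visible in the two hunks are transcribed (the else branch of d_x, hidden between them, is omitted), and the function name is ours, not Paddle's:

import numpy as np

def layer_norm_backward_batch1(x, d_y, mean, var, epsilon=1e-5):
    # x, d_y have shape [feature_size]; mean and var have shape [1] because
    # batch_size == 1, hence the scalar accesses mean[0] and var[0] from the fix.
    var_val = np.sqrt(var[0] + epsilon)
    d_x = d_y / var_val                      # branch where d_scale is null
    d_scale = d_y * (x - mean[0]) / var_val
    d_bias = d_y.copy()
    return d_x, d_scale, d_bias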

python/paddle/fluid/tests/unittests/test_layer_norm_op.py

Lines changed: 1 addition & 0 deletions

@@ -233,6 +233,7 @@ def test_with_place(place,
         test_with_place(place, shape, begin_norm_axis)

     def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(
             shape=[2, 3, 4, 5],
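
The new case can also be exercised end to end. A rough standalone equivalent (assuming a GPU build of Paddle, since the patched kernel is CUDA code, and using the paddle.nn.LayerNorm API rather than the test harness) might look like:

import numpy as np
import paddle

paddle.set_device('gpu')  # the patched kernel only runs on the CUDA path
x = paddle.to_tensor(np.random.rand(1, 3, 4, 5).astype('float32'))
x.stop_gradient = False
layer_norm = paddle.nn.LayerNorm(normalized_shape=[3, 4, 5])  # begin_norm_axis=1
y = layer_norm(x)
y.sum().backward()
# Before the fix, this backward pass read mean/var out of bounds for batch_size=1.
print(x.grad.shape)  # [1, 3, 4, 5]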
