Skip to content

Commit 332141f

Browse files
committed
Fix gradients with ignore_idx in softmax_with_cross_entropy on cpu.
Remove softmax_with_cross_entropy from op_threshold_white_list. test=develop
1 parent 6ca6aaa commit 332141f

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,13 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
116116
for (int i = 0; i < n; ++i) {
117117
for (int j = 0; j < remain; j++) {
118118
int idx = i * remain + j;
119-
logit_grad_data[i * d + label_data[idx] * remain + j] -=
120-
out_grad_data[idx];
121119
if (label_data[idx] == ignore_index) {
122120
for (int k = 0; k < axis_dim; ++k) {
123121
logit_grad_data[i * d + k * remain + j] = 0;
124122
}
123+
} else {
124+
logit_grad_data[i * d + label_data[idx] * remain + j] -=
125+
out_grad_data[idx];
125126
}
126127
}
127128
}

python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,17 @@ def setUp(self):
8383
self.attrs = {
8484
"numeric_stable_mode": self.numeric_stable_mode,
8585
"soft_label": self.soft_label,
86+
"ignore_index": self.ignore_index,
8687
}
87-
if self.ignore_index >= 0:
88-
self.attrs['ignore_index'] = self.ignore_index
88+
8989
if self.axis != -1:
9090
self.attrs['axis'] = self.axis
9191

9292
def test_check_output(self):
9393
self.check_output()
9494

9595
def test_check_grad(self):
96-
self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
96+
self.check_grad(["Logits"], "Loss")
9797

9898

9999
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):

python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
'selu', \
3737
'sigmoid_cross_entropy_with_logits', \
3838
'soft_relu', \
39-
'softmax_with_cross_entropy', \
4039
'spp', \
4140
'teacher_student_sigmoid_loss', \
4241
'unpool', \

0 commit comments

Comments
 (0)