Skip to content

Commit 6714fd4

Browse files
committed
Fix
1 parent 2cda62c commit 6714fd4

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3067,6 +3067,11 @@ def cross_entropy(
30673067
# so, reduce_sum all directly is ok
30683068
return _C_ops.sum(out, [], None, False)
30693069
elif reduction == "mean":
3070+
# when reduction is "mean" and the input is zero-sized, fill the output with NaN
# NOTE(review): the two lines below look wrong — paddle.full(out.shape, paddle.bool)
# passes a dtype as the fill_value (likely intended paddle.full(out.shape, True,
# dtype='bool')), and paddle.nan is not a documented attribute (float('nan') is
# the usual spelling) — verify against the paddle.full / masked_fill_ API docs.
3071+
if input.size == 0:
3072+
mask = paddle.full(out.shape, paddle.bool)
3073+
paddle.masked_fill_(out, mask, paddle.nan)
3074+
30703075
# 1. if weight==none,
30713076
# numerator: reduce_sum all loss directly is ok because of base_softmax_with_cross_entropy's inner logic
30723077
# denominator: count sample num with class_index!=ignore_index

test/legacy_test/test_cross_entropy_op.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,5 +468,34 @@ def test_input_dims():
468468
self.assertRaises(ValueError, test_input_dims)
469469

470470

471+
class TestCrossEntropyOp_ZeroSize(TestCrossEntropyOp):
    """Cross-entropy op test with a zero-sized batch dimension.

    Reuses the TestCrossEntropyOp setup machinery, but with
    ``batch_size = 0`` so the op is exercised on empty inputs.
    """

    def setUp(self):
        self.op_type = "cross_entropy"
        self.python_api = api_wrapper
        self.soft_label = False
        self.ignore_index = -100
        self.dtype = np.float64
        # The zero-sized batch is the case under test.
        self.batch_size = 0
        self.class_num = 10

        # Hooks inherited from TestCrossEntropyOp build x / label
        # and the expected output from the attributes set above.
        self.init_dtype_type()
        self.init_attr_type()
        self.init_bs_class_num()
        self.init_x()
        self.init_label()
        self.get_cross_entropy()

        self.inputs = {"X": self.x, "Label": self.label}
        self.outputs = {"Y": self.cross_entropy}
        self.attrs = {
            "soft_label": self.soft_label,
            "ignore_index": self.ignore_index,
        }

    def get_cross_entropy(self):
        # Expected output is an empty (0, 1) float64 array; the random
        # values are irrelevant because there are zero rows.
        self.cross_entropy = np.random.random([0, 1]).astype(np.float64)
499+
471500
if __name__ == "__main__":
472501
unittest.main()

0 commit comments

Comments
 (0)