File tree Expand file tree Collapse file tree 2 files changed +34
-0
lines changed
python/paddle/nn/functional Expand file tree Collapse file tree 2 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -3067,6 +3067,11 @@ def cross_entropy(
30673067 # so, reduce_sum all directly is ok
30683068 return _C_ops .sum (out , [], None , False )
30693069 elif reduction == "mean" :
3070+ # when reduction is mean, use paddle.nan
3071+ if input .size == 0 :
3072+ mask = paddle .full (out .shape , paddle .bool )
3073+ paddle .masked_fill_ (out , mask , paddle .nan )
3074+
30703075 # 1. if weight==none,
30713076 # numerator: reduce_sum all loss directly is ok causeof base_softmax_with_cross_entropy's inner logic
30723077 # denominator: count sample num with class_index!=ignore_index
Original file line number Diff line number Diff line change @@ -468,5 +468,34 @@ def test_input_dims():
468468 self .assertRaises (ValueError , test_input_dims )
469469
470470
471+ class TestCrossEntropyOp_ZeroSize (TestCrossEntropyOp ):
472+ def setUp (self ):
473+ self .op_type = "cross_entropy"
474+ self .python_api = api_wrapper
475+ self .soft_label = False
476+ self .ignore_index = - 100
477+ self .dtype = np .float64
478+ # 0-size
479+ self .batch_size = 0
480+ self .class_num = 10
481+
482+ self .init_dtype_type ()
483+ self .init_attr_type ()
484+ self .init_bs_class_num ()
485+ self .init_x ()
486+ self .init_label ()
487+ self .get_cross_entropy ()
488+
489+ self .inputs = {"X" : self .x , "Label" : self .label }
490+ self .outputs = {"Y" : self .cross_entropy }
491+ self .attrs = {
492+ "soft_label" : self .soft_label ,
493+ "ignore_index" : self .ignore_index ,
494+ }
495+
496+ def get_cross_entropy (self ):
497+ self .cross_entropy = np .random .random ([0 , 1 ]).astype (np .float64 )
498+
499+
471500if __name__ == "__main__" :
472501 unittest .main ()
You can’t perform that action at this time.
0 commit comments