2828def accuracy (pred , label , topk = (1 , )):
2929 maxk = max (topk )
3030 pred = np .argsort (pred )[:, ::- 1 ][:, :maxk ]
31+ label = label .reshape (- 1 , 1 )
3132 correct = (pred == np .repeat (label , maxk , 1 ))
3233
3334 batch_size = label .shape [0 ]
@@ -47,21 +48,27 @@ def convert_to_one_hot(y, C):
4748
4849
4950class TestAccuracy (unittest .TestCase ):
50- def test_acc (self ):
51+ def test_acc (self , squeeze_y = False ):
5152 paddle .disable_static ()
5253
5354 x = paddle .to_tensor (
5455 np .array ([[0.1 , 0.2 , 0.3 , 0.4 ], [0.1 , 0.4 , 0.3 , 0.2 ],
5556 [0.1 , 0.2 , 0.4 , 0.3 ], [0.1 , 0.2 , 0.3 , 0.4 ]]))
56- y = paddle .to_tensor (np .array ([[0 ], [1 ], [2 ], [3 ]]))
57+
58+ y = np .array ([[0 ], [1 ], [2 ], [3 ]])
59+ if squeeze_y :
60+ y = y .squeeze ()
61+
62+ y = paddle .to_tensor (y )
5763
5864 m = paddle .metric .Accuracy (name = 'my_acc' )
5965
6066 # check name
6167 self .assertEqual (m .name (), ['my_acc' ])
6268
6369 correct = m .compute (x , y )
64- # check results
70+ # check shape and results
71+ self .assertEqual (correct .shape , [4 , 1 ])
6572 self .assertEqual (m .update (correct ), 0.75 )
6673 self .assertEqual (m .accumulate (), 0.75 )
6774
@@ -80,19 +87,25 @@ def test_acc(self):
8087 self .assertEqual (m .count [0 ], 0.0 )
8188 paddle .enable_static ()
8289
90+ def test_1d_label (self ):
91+ self .test_acc (True )
92+
8393
8494class TestAccuracyDynamic (unittest .TestCase ):
8595 def setUp (self ):
8696 self .topk = (1 , )
8797 self .class_num = 5
8898 self .sample_num = 1000
8999 self .name = None
100+ self .squeeze_label = False
90101
91102 def random_pred_label (self ):
92103 label = np .random .randint (0 , self .class_num ,
93104 (self .sample_num , 1 )).astype ('int64' )
94105 pred = np .random .randint (0 , self .class_num ,
95106 (self .sample_num , 1 )).astype ('int32' )
107+ if self .squeeze_label :
108+ label = label .squeeze ()
96109 pred_one_hot = convert_to_one_hot (pred , self .class_num )
97110 pred_one_hot = pred_one_hot .astype ('float32' )
98111
@@ -123,9 +136,14 @@ def setUp(self):
123136 self .class_num = 10
124137 self .sample_num = 1000
125138 self .name = "accuracy"
139+ self .squeeze_label = True
126140
127141
128142class TestAccuracyStatic (TestAccuracyDynamic ):
143+ def setUp (self ):
144+ super ().setUp ()
145+ self .squeeze_label = True
146+
129147 def test_main (self ):
130148 main_prog = fluid .Program ()
131149 startup_prog = fluid .Program ()
@@ -164,6 +182,7 @@ def setUp(self):
164182 self .class_num = 10
165183 self .sample_num = 100
166184 self .name = "accuracy"
185+ self .squeeze_label = False
167186
168187
169188class TestPrecision (unittest .TestCase ):
0 commit comments