Skip to content

Commit 4289442

Browse files
committed
Support squeezed label as input in paddle.metric.Accuracy
1 parent fdc06f2 commit 4289442

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

python/paddle/metric/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def compute(self, pred, label, *args):
244244
Tensor: Correct mask, a tensor with shape [batch_size, topk].
245245
"""
246246
pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
247+
label = paddle.reshape(label, (-1, 1))
247248
correct = pred == label
248249
return paddle.cast(correct, dtype='float32')
249250

python/paddle/tests/test_metrics.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
def accuracy(pred, label, topk=(1, )):
2929
maxk = max(topk)
3030
pred = np.argsort(pred)[:, ::-1][:, :maxk]
31+
label = label.reshape(-1, 1)
3132
correct = (pred == np.repeat(label, maxk, 1))
3233

3334
batch_size = label.shape[0]
@@ -47,21 +48,27 @@ def convert_to_one_hot(y, C):
4748

4849

4950
class TestAccuracy(unittest.TestCase):
50-
def test_acc(self):
51+
def test_acc(self, squeeze_y=False):
5152
paddle.disable_static()
5253

5354
x = paddle.to_tensor(
5455
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
5556
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
56-
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
57+
58+
y = np.array([[0], [1], [2], [3]])
59+
if squeeze_y:
60+
y = y.squeeze()
61+
62+
y = paddle.to_tensor(y)
5763

5864
m = paddle.metric.Accuracy(name='my_acc')
5965

6066
# check name
6167
self.assertEqual(m.name(), ['my_acc'])
6268

6369
correct = m.compute(x, y)
64-
# check results
70+
# check shape and results
71+
self.assertEqual(correct.shape, [4, 1])
6572
self.assertEqual(m.update(correct), 0.75)
6673
self.assertEqual(m.accumulate(), 0.75)
6774

@@ -80,19 +87,25 @@ def test_acc(self):
8087
self.assertEqual(m.count[0], 0.0)
8188
paddle.enable_static()
8289

90+
def test_1d_label(self):
91+
self.test_acc(True)
92+
8393

8494
class TestAccuracyDynamic(unittest.TestCase):
8595
def setUp(self):
8696
self.topk = (1, )
8797
self.class_num = 5
8898
self.sample_num = 1000
8999
self.name = None
100+
self.squeeze_label = False
90101

91102
def random_pred_label(self):
92103
label = np.random.randint(0, self.class_num,
93104
(self.sample_num, 1)).astype('int64')
94105
pred = np.random.randint(0, self.class_num,
95106
(self.sample_num, 1)).astype('int32')
107+
if self.squeeze_label:
108+
label = label.squeeze()
96109
pred_one_hot = convert_to_one_hot(pred, self.class_num)
97110
pred_one_hot = pred_one_hot.astype('float32')
98111

@@ -123,9 +136,14 @@ def setUp(self):
123136
self.class_num = 10
124137
self.sample_num = 1000
125138
self.name = "accuracy"
139+
self.squeeze_label = True
126140

127141

128142
class TestAccuracyStatic(TestAccuracyDynamic):
143+
def setUp(self):
144+
super().setUp()
145+
self.squeeze_label = True
146+
129147
def test_main(self):
130148
main_prog = fluid.Program()
131149
startup_prog = fluid.Program()
@@ -164,6 +182,7 @@ def setUp(self):
164182
self.class_num = 10
165183
self.sample_num = 100
166184
self.name = "accuracy"
185+
self.squeeze_label = False
167186

168187

169188
class TestPrecision(unittest.TestCase):

python/paddle/vision/datasets/cifar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ class Cifar10(Dataset):
5353
:attr:`data_file` is not set. Default True
5454
5555
Returns:
56-
Dataset: instance of cifar-10 dataset
56+
Dataset: instance of cifar-10 dataset. If transform is None, the shape
57+
of each data iterm is [3, 32, 32], and default dtype is float32.
58+
The dtype of label is int64.
5759
5860
Examples:
5961

0 commit comments

Comments
 (0)