Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/paddle/metric/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def compute(self, pred, label, *args):
Tensor: Correct mask, a tensor with shape [batch_size, topk].
"""
pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
label = paddle.reshape(label, (-1, 1))
correct = pred == label
return paddle.cast(correct, dtype='float32')

Expand Down
25 changes: 22 additions & 3 deletions python/paddle/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
label = label.reshape(-1, 1)
correct = (pred == np.repeat(label, maxk, 1))

batch_size = label.shape[0]
Expand All @@ -47,21 +48,27 @@ def convert_to_one_hot(y, C):


class TestAccuracy(unittest.TestCase):
def test_acc(self):
def test_acc(self, squeeze_y=False):
paddle.disable_static()

x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))

y = np.array([[0], [1], [2], [3]])
if squeeze_y:
y = y.squeeze()

y = paddle.to_tensor(y)

m = paddle.metric.Accuracy(name='my_acc')

# check name
self.assertEqual(m.name(), ['my_acc'])

correct = m.compute(x, y)
# check results
# check shape and results
self.assertEqual(correct.shape, [4, 1])
self.assertEqual(m.update(correct), 0.75)
self.assertEqual(m.accumulate(), 0.75)

Expand All @@ -80,19 +87,25 @@ def test_acc(self):
self.assertEqual(m.count[0], 0.0)
paddle.enable_static()

def test_1d_label(self):
self.test_acc(True)


class TestAccuracyDynamic(unittest.TestCase):
def setUp(self):
self.topk = (1, )
self.class_num = 5
self.sample_num = 1000
self.name = None
self.squeeze_label = False

def random_pred_label(self):
label = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int64')
pred = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int32')
if self.squeeze_label:
label = label.squeeze()
pred_one_hot = convert_to_one_hot(pred, self.class_num)
pred_one_hot = pred_one_hot.astype('float32')

Expand Down Expand Up @@ -123,9 +136,14 @@ def setUp(self):
self.class_num = 10
self.sample_num = 1000
self.name = "accuracy"
self.squeeze_label = True


class TestAccuracyStatic(TestAccuracyDynamic):
def setUp(self):
super().setUp()
self.squeeze_label = True

def test_main(self):
main_prog = fluid.Program()
startup_prog = fluid.Program()
Expand Down Expand Up @@ -164,6 +182,7 @@ def setUp(self):
self.class_num = 10
self.sample_num = 100
self.name = "accuracy"
self.squeeze_label = False


class TestPrecision(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/vision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Cifar10(Dataset):
default backend is 'pil'. Default: None.

Returns:
Dataset: instance of cifar-10 dataset
Dataset: instance of cifar-10 dataset.

Examples:

Expand Down