Skip to content

Commit 93e8ae0

Browse files
authored
Revert "fix cifar label dimension. test=develop (#33475)"
This reverts commit 6c11034.
1 parent 1dfd857 commit 93e8ae0

File tree

2 files changed

+3
-16
lines changed

2 files changed

+3
-16
lines changed

python/paddle/tests/test_dataset_cifar.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ def test_main(self):
3232
self.assertTrue(data.shape[2] == 3)
3333
self.assertTrue(data.shape[1] == 32)
3434
self.assertTrue(data.shape[0] == 32)
35-
self.assertTrue(len(label.shape) == 1)
36-
self.assertTrue(label.shape[0] == 1)
3735
self.assertTrue(0 <= int(label) <= 9)
3836

3937

@@ -51,8 +49,6 @@ def test_main(self):
5149
self.assertTrue(data.shape[2] == 3)
5250
self.assertTrue(data.shape[1] == 32)
5351
self.assertTrue(data.shape[0] == 32)
54-
self.assertTrue(len(label.shape) == 1)
55-
self.assertTrue(label.shape[0] == 1)
5652
self.assertTrue(0 <= int(label) <= 9)
5753

5854
# test cv2 backend
@@ -67,8 +63,6 @@ def test_main(self):
6763
self.assertTrue(data.shape[2] == 3)
6864
self.assertTrue(data.shape[1] == 32)
6965
self.assertTrue(data.shape[0] == 32)
70-
self.assertTrue(len(label.shape) == 1)
71-
self.assertTrue(label.shape[0] == 1)
7266
self.assertTrue(0 <= int(label) <= 99)
7367

7468
with self.assertRaises(ValueError):
@@ -89,8 +83,6 @@ def test_main(self):
8983
self.assertTrue(data.shape[2] == 3)
9084
self.assertTrue(data.shape[1] == 32)
9185
self.assertTrue(data.shape[0] == 32)
92-
self.assertTrue(len(label.shape) == 1)
93-
self.assertTrue(label.shape[0] == 1)
9486
self.assertTrue(0 <= int(label) <= 99)
9587

9688

@@ -108,8 +100,6 @@ def test_main(self):
108100
self.assertTrue(data.shape[2] == 3)
109101
self.assertTrue(data.shape[1] == 32)
110102
self.assertTrue(data.shape[0] == 32)
111-
self.assertTrue(len(label.shape) == 1)
112-
self.assertTrue(label.shape[0] == 1)
113103
self.assertTrue(0 <= int(label) <= 99)
114104

115105
# test cv2 backend
@@ -124,8 +114,6 @@ def test_main(self):
124114
self.assertTrue(data.shape[2] == 3)
125115
self.assertTrue(data.shape[1] == 32)
126116
self.assertTrue(data.shape[0] == 32)
127-
self.assertTrue(len(label.shape) == 1)
128-
self.assertTrue(label.shape[0] == 1)
129117
self.assertTrue(0 <= int(label) <= 99)
130118

131119
with self.assertRaises(ValueError):

python/paddle/vision/datasets/cifar.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ def _load_data(self):
148148
six.b('labels'), batch.get(six.b('fine_labels'), None))
149149
assert labels is not None
150150
for sample, label in six.moves.zip(data, labels):
151-
self.data.append((sample,
152-
np.array([label]).astype('int64')))
151+
self.data.append((sample, label))
153152

154153
def __getitem__(self, idx):
155154
image, label = self.data[idx]
@@ -162,9 +161,9 @@ def __getitem__(self, idx):
162161
image = self.transform(image)
163162

164163
if self.backend == 'pil':
165-
return image, label.astype('int64')
164+
return image, np.array(label).astype('int64')
166165

167-
return image.astype(self.dtype), label.astype('int64')
166+
return image.astype(self.dtype), np.array(label).astype('int64')
168167

169168
def __len__(self):
170169
return len(self.data)

0 commit comments

Comments
 (0)