Skip to content

Commit e564f13

Browse files
dizcza authored and fmassa committed
MNIST and FashionMNIST now have their own 'raw' and 'processed' folders (#601)
* MNIST and FashionMNIST now have their own 'raw' and 'processed' folders
* mkdir exist_ok
1 parent a7935ea commit e564f13

File tree

2 files changed

+71
-67
lines changed

2 files changed

+71
-67
lines changed

torchvision/datasets/mnist.py

Lines changed: 57 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,11 @@
33
from PIL import Image
44
import os
55
import os.path
6-
import errno
6+
import gzip
77
import numpy as np
88
import torch
99
import codecs
10-
from .utils import download_url
10+
from .utils import download_url, makedir_exist_ok
1111

1212

1313
class MNIST(data.Dataset):
@@ -32,13 +32,10 @@ class MNIST(data.Dataset):
3232
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
3333
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
3434
]
35-
raw_folder = 'raw'
36-
processed_folder = 'processed'
3735
training_file = 'training.pt'
3836
test_file = 'test.pt'
3937
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
4038
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
41-
class_to_idx = {_class: i for i, _class in enumerate(classes)}
4239

4340
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
4441
self.root = os.path.expanduser(root)
@@ -57,7 +54,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
5754
data_file = self.training_file
5855
else:
5956
data_file = self.test_file
60-
self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file))
57+
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
6158

6259
def __getitem__(self, index):
6360
"""
@@ -84,51 +81,61 @@ def __getitem__(self, index):
8481
def __len__(self):
8582
return len(self.data)
8683

84+
@property
85+
def raw_folder(self):
86+
return os.path.join(self.root, self.__class__.__name__, 'raw')
87+
88+
@property
89+
def processed_folder(self):
90+
return os.path.join(self.root, self.__class__.__name__, 'processed')
91+
92+
@property
93+
def class_to_idx(self):
94+
return {_class: i for i, _class in enumerate(self.classes)}
95+
8796
def _check_exists(self):
88-
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
89-
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
97+
return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
98+
os.path.exists(os.path.join(self.processed_folder, self.test_file))
99+
100+
@staticmethod
101+
def extract_gzip(gzip_path, remove_finished=False):
102+
print('Extracting {}'.format(gzip_path))
103+
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
104+
gzip.GzipFile(gzip_path) as zip_f:
105+
out_f.write(zip_f.read())
106+
if remove_finished:
107+
os.unlink(gzip_path)
90108

91109
def download(self):
92110
"""Download the MNIST data if it doesn't exist in processed_folder already."""
93-
import gzip
94111

95112
if self._check_exists():
96113
return
97114

98-
# download files
99-
try:
100-
os.makedirs(os.path.join(self.root, self.raw_folder))
101-
os.makedirs(os.path.join(self.root, self.processed_folder))
102-
except OSError as e:
103-
if e.errno == errno.EEXIST:
104-
pass
105-
else:
106-
raise
115+
makedir_exist_ok(self.raw_folder)
116+
makedir_exist_ok(self.processed_folder)
107117

118+
# download files
108119
for url in self.urls:
109120
filename = url.rpartition('/')[2]
110-
file_path = os.path.join(self.root, self.raw_folder, filename)
111-
download_url(url, root=os.path.join(self.root, self.raw_folder),
112-
filename=filename, md5=None)
113-
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
114-
gzip.GzipFile(file_path) as zip_f:
115-
out_f.write(zip_f.read())
116-
os.unlink(file_path)
121+
file_path = os.path.join(self.raw_folder, filename)
122+
download_url(url, root=self.raw_folder, filename=filename, md5=None)
123+
self.extract_gzip(gzip_path=file_path, remove_finished=True)
117124

118125
# process and save as torch files
119126
print('Processing...')
120127

121128
training_set = (
122-
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
123-
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
129+
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
130+
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
124131
)
125132
test_set = (
126-
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
127-
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
133+
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
134+
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
128135
)
129-
with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
136+
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
130137
torch.save(training_set, f)
131-
with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
138+
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
132139
torch.save(test_set, f)
133140

134141
print('Done!')
@@ -170,7 +177,6 @@ class FashionMNIST(MNIST):
170177
]
171178
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
172179
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
173-
class_to_idx = {_class: i for i, _class in enumerate(classes)}
174180

175181

176182
class EMNIST(MNIST):
@@ -205,64 +211,55 @@ def __init__(self, root, split, **kwargs):
205211
self.test_file = self._test_file(split)
206212
super(EMNIST, self).__init__(root, **kwargs)
207213

208-
def _training_file(self, split):
214+
@staticmethod
215+
def _training_file(split):
209216
return 'training_{}.pt'.format(split)
210217

211-
def _test_file(self, split):
218+
@staticmethod
219+
def _test_file(split):
212220
return 'test_{}.pt'.format(split)
213221

214222
def download(self):
215223
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
216-
import gzip
217224
import shutil
218225
import zipfile
219226

220227
if self._check_exists():
221228
return
222229

223-
# download files
224-
try:
225-
os.makedirs(os.path.join(self.root, self.raw_folder))
226-
os.makedirs(os.path.join(self.root, self.processed_folder))
227-
except OSError as e:
228-
if e.errno == errno.EEXIST:
229-
pass
230-
else:
231-
raise
230+
makedir_exist_ok(self.raw_folder)
231+
makedir_exist_ok(self.processed_folder)
232232

233+
# download files
233234
filename = self.url.rpartition('/')[2]
234-
raw_folder = os.path.join(self.root, self.raw_folder)
235-
file_path = os.path.join(raw_folder, filename)
236-
download_url(self.url, root=file_path, filename=filename, md5=None)
235+
file_path = os.path.join(self.raw_folder, filename)
236+
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
237237

238238
print('Extracting zip archive')
239239
with zipfile.ZipFile(file_path) as zip_f:
240-
zip_f.extractall(raw_folder)
240+
zip_f.extractall(self.raw_folder)
241241
os.unlink(file_path)
242-
gzip_folder = os.path.join(raw_folder, 'gzip')
242+
gzip_folder = os.path.join(self.raw_folder, 'gzip')
243243
for gzip_file in os.listdir(gzip_folder):
244244
if gzip_file.endswith('.gz'):
245-
print('Extracting ' + gzip_file)
246-
with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
247-
gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
248-
out_f.write(zip_f.read())
249-
shutil.rmtree(gzip_folder)
245+
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
250246

251247
# process and save as torch files
252248
for split in self.splits:
253249
print('Processing ' + split)
254250
training_set = (
255-
read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
256-
read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
251+
read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
252+
read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
257253
)
258254
test_set = (
259-
read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
260-
read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
255+
read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
256+
read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
261257
)
262-
with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
258+
with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
263259
torch.save(training_set, f)
264-
with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
260+
with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
265261
torch.save(test_set, f)
262+
shutil.rmtree(gzip_folder)
266263

267264
print('Done!')
268265

torchvision/datasets/utils.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -31,20 +31,27 @@ def check_integrity(fpath, md5=None):
3131
return True
3232

3333

34-
def download_url(url, root, filename, md5):
35-
from six.moves import urllib
36-
37-
root = os.path.expanduser(root)
38-
fpath = os.path.join(root, filename)
39-
34+
def makedir_exist_ok(dirpath):
35+
"""
36+
Python2 support for os.makedirs(.., exist_ok=True)
37+
"""
4038
try:
41-
os.makedirs(root)
39+
os.makedirs(dirpath)
4240
except OSError as e:
4341
if e.errno == errno.EEXIST:
4442
pass
4543
else:
4644
raise
4745

46+
47+
def download_url(url, root, filename, md5):
48+
from six.moves import urllib
49+
50+
root = os.path.expanduser(root)
51+
fpath = os.path.join(root, filename)
52+
53+
makedir_exist_ok(root)
54+
4855
# downloads file
4956
if os.path.isfile(fpath) and check_integrity(fpath, md5):
5057
print('Using downloaded and verified file: ' + fpath)

0 commit comments

Comments
 (0)