33from PIL import Image
44import os
55import os .path
6- import errno
6+ import gzip
77import numpy as np
88import torch
99import codecs
10- from .utils import download_url
10+ from .utils import download_url , makedir_exist_ok
1111
1212
1313class MNIST (data .Dataset ):
@@ -32,13 +32,10 @@ class MNIST(data.Dataset):
3232 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' ,
3333 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' ,
3434 ]
35- raw_folder = 'raw'
36- processed_folder = 'processed'
3735 training_file = 'training.pt'
3836 test_file = 'test.pt'
3937 classes = ['0 - zero' , '1 - one' , '2 - two' , '3 - three' , '4 - four' ,
4038 '5 - five' , '6 - six' , '7 - seven' , '8 - eight' , '9 - nine' ]
41- class_to_idx = {_class : i for i , _class in enumerate (classes )}
4239
4340 def __init__ (self , root , train = True , transform = None , target_transform = None , download = False ):
4441 self .root = os .path .expanduser (root )
@@ -57,7 +54,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
5754 data_file = self .training_file
5855 else :
5956 data_file = self .test_file
60- self .data , self .targets = torch .load (os .path .join (self .root , self . processed_folder , data_file ))
57+ self .data , self .targets = torch .load (os .path .join (self .processed_folder , data_file ))
6158
6259 def __getitem__ (self , index ):
6360 """
@@ -84,51 +81,61 @@ def __getitem__(self, index):
8481 def __len__ (self ):
8582 return len (self .data )
8683
84+ @property
85+ def raw_folder (self ):
86+ return os .path .join (self .root , self .__class__ .__name__ , 'raw' )
87+
88+ @property
89+ def processed_folder (self ):
90+ return os .path .join (self .root , self .__class__ .__name__ , 'processed' )
91+
92+ @property
93+ def class_to_idx (self ):
94+ return {_class : i for i , _class in enumerate (self .classes )}
95+
8796 def _check_exists (self ):
88- return os .path .exists (os .path .join (self .root , self .processed_folder , self .training_file )) and \
89- os .path .exists (os .path .join (self .root , self .processed_folder , self .test_file ))
97+ return os .path .exists (os .path .join (self .processed_folder , self .training_file )) and \
98+ os .path .exists (os .path .join (self .processed_folder , self .test_file ))
99+
100+ @staticmethod
101+ def extract_gzip (gzip_path , remove_finished = False ):
102+ print ('Extracting {}' .format (gzip_path ))
103+ with open (gzip_path .replace ('.gz' , '' ), 'wb' ) as out_f , \
104+ gzip .GzipFile (gzip_path ) as zip_f :
105+ out_f .write (zip_f .read ())
106+ if remove_finished :
107+ os .unlink (gzip_path )
90108
91109 def download (self ):
92110 """Download the MNIST data if it doesn't exist in processed_folder already."""
93- import gzip
94111
95112 if self ._check_exists ():
96113 return
97114
98- # download files
99- try :
100- os .makedirs (os .path .join (self .root , self .raw_folder ))
101- os .makedirs (os .path .join (self .root , self .processed_folder ))
102- except OSError as e :
103- if e .errno == errno .EEXIST :
104- pass
105- else :
106- raise
115+ makedir_exist_ok (self .raw_folder )
116+ makedir_exist_ok (self .processed_folder )
107117
118+ # download files
108119 for url in self .urls :
109120 filename = url .rpartition ('/' )[2 ]
110- file_path = os .path .join (self .root , self .raw_folder , filename )
111- download_url (url , root = os .path .join (self .root , self .raw_folder ),
112- filename = filename , md5 = None )
113- with open (file_path .replace ('.gz' , '' ), 'wb' ) as out_f , \
114- gzip .GzipFile (file_path ) as zip_f :
115- out_f .write (zip_f .read ())
116- os .unlink (file_path )
121+ file_path = os .path .join (self .raw_folder , filename )
122+ download_url (url , root = self .raw_folder , filename = filename , md5 = None )
123+ self .extract_gzip (gzip_path = file_path , remove_finished = True )
117124
118125 # process and save as torch files
119126 print ('Processing...' )
120127
121128 training_set = (
122- read_image_file (os .path .join (self .root , self . raw_folder , 'train-images-idx3-ubyte' )),
123- read_label_file (os .path .join (self .root , self . raw_folder , 'train-labels-idx1-ubyte' ))
129+ read_image_file (os .path .join (self .raw_folder , 'train-images-idx3-ubyte' )),
130+ read_label_file (os .path .join (self .raw_folder , 'train-labels-idx1-ubyte' ))
124131 )
125132 test_set = (
126- read_image_file (os .path .join (self .root , self . raw_folder , 't10k-images-idx3-ubyte' )),
127- read_label_file (os .path .join (self .root , self . raw_folder , 't10k-labels-idx1-ubyte' ))
133+ read_image_file (os .path .join (self .raw_folder , 't10k-images-idx3-ubyte' )),
134+ read_label_file (os .path .join (self .raw_folder , 't10k-labels-idx1-ubyte' ))
128135 )
129- with open (os .path .join (self .root , self . processed_folder , self .training_file ), 'wb' ) as f :
136+ with open (os .path .join (self .processed_folder , self .training_file ), 'wb' ) as f :
130137 torch .save (training_set , f )
131- with open (os .path .join (self .root , self . processed_folder , self .test_file ), 'wb' ) as f :
138+ with open (os .path .join (self .processed_folder , self .test_file ), 'wb' ) as f :
132139 torch .save (test_set , f )
133140
134141 print ('Done!' )
@@ -170,7 +177,6 @@ class FashionMNIST(MNIST):
170177 ]
171178 classes = ['T-shirt/top' , 'Trouser' , 'Pullover' , 'Dress' , 'Coat' , 'Sandal' ,
172179 'Shirt' , 'Sneaker' , 'Bag' , 'Ankle boot' ]
173- class_to_idx = {_class : i for i , _class in enumerate (classes )}
174180
175181
176182class EMNIST (MNIST ):
@@ -205,64 +211,55 @@ def __init__(self, root, split, **kwargs):
205211 self .test_file = self ._test_file (split )
206212 super (EMNIST , self ).__init__ (root , ** kwargs )
207213
208- def _training_file (self , split ):
214+ @staticmethod
215+ def _training_file (split ):
209216 return 'training_{}.pt' .format (split )
210217
211- def _test_file (self , split ):
218+ @staticmethod
219+ def _test_file (split ):
212220 return 'test_{}.pt' .format (split )
213221
214222 def download (self ):
215223 """Download the EMNIST data if it doesn't exist in processed_folder already."""
216- import gzip
217224 import shutil
218225 import zipfile
219226
220227 if self ._check_exists ():
221228 return
222229
223- # download files
224- try :
225- os .makedirs (os .path .join (self .root , self .raw_folder ))
226- os .makedirs (os .path .join (self .root , self .processed_folder ))
227- except OSError as e :
228- if e .errno == errno .EEXIST :
229- pass
230- else :
231- raise
230+ makedir_exist_ok (self .raw_folder )
231+ makedir_exist_ok (self .processed_folder )
232232
233+ # download files
233234 filename = self .url .rpartition ('/' )[2 ]
234- raw_folder = os .path .join (self .root , self .raw_folder )
235- file_path = os .path .join (raw_folder , filename )
236- download_url (self .url , root = file_path , filename = filename , md5 = None )
235+ file_path = os .path .join (self .raw_folder , filename )
236+ download_url (self .url , root = self .raw_folder , filename = filename , md5 = None )
237237
238238 print ('Extracting zip archive' )
239239 with zipfile .ZipFile (file_path ) as zip_f :
240- zip_f .extractall (raw_folder )
240+ zip_f .extractall (self . raw_folder )
241241 os .unlink (file_path )
242- gzip_folder = os .path .join (raw_folder , 'gzip' )
242+ gzip_folder = os .path .join (self . raw_folder , 'gzip' )
243243 for gzip_file in os .listdir (gzip_folder ):
244244 if gzip_file .endswith ('.gz' ):
245- print ('Extracting ' + gzip_file )
246- with open (os .path .join (raw_folder , gzip_file .replace ('.gz' , '' )), 'wb' ) as out_f , \
247- gzip .GzipFile (os .path .join (gzip_folder , gzip_file )) as zip_f :
248- out_f .write (zip_f .read ())
249- shutil .rmtree (gzip_folder )
245+ self .extract_gzip (gzip_path = os .path .join (gzip_folder , gzip_file ))
250246
251247 # process and save as torch files
252248 for split in self .splits :
253249 print ('Processing ' + split )
254250 training_set = (
255- read_image_file (os .path .join (raw_folder , 'emnist-{}-train-images-idx3-ubyte' .format (split ))),
256- read_label_file (os .path .join (raw_folder , 'emnist-{}-train-labels-idx1-ubyte' .format (split )))
251+ read_image_file (os .path .join (gzip_folder , 'emnist-{}-train-images-idx3-ubyte' .format (split ))),
252+ read_label_file (os .path .join (gzip_folder , 'emnist-{}-train-labels-idx1-ubyte' .format (split )))
257253 )
258254 test_set = (
259- read_image_file (os .path .join (raw_folder , 'emnist-{}-test-images-idx3-ubyte' .format (split ))),
260- read_label_file (os .path .join (raw_folder , 'emnist-{}-test-labels-idx1-ubyte' .format (split )))
255+ read_image_file (os .path .join (gzip_folder , 'emnist-{}-test-images-idx3-ubyte' .format (split ))),
256+ read_label_file (os .path .join (gzip_folder , 'emnist-{}-test-labels-idx1-ubyte' .format (split )))
261257 )
262- with open (os .path .join (self .root , self . processed_folder , self ._training_file (split )), 'wb' ) as f :
258+ with open (os .path .join (self .processed_folder , self ._training_file (split )), 'wb' ) as f :
263259 torch .save (training_set , f )
264- with open (os .path .join (self .root , self . processed_folder , self ._test_file (split )), 'wb' ) as f :
260+ with open (os .path .join (self .processed_folder , self ._test_file (split )), 'wb' ) as f :
265261 torch .save (test_set , f )
262+ shutil .rmtree (gzip_folder )
266263
267264 print ('Done!' )
268265
0 commit comments