# Interpolation method used by tf.image.resize_images when rescaling source images.
_DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR
# Number of examples buffered in memory for dataset shuffling.
_SHUFFLE_BUFFER = 1024
3535
36+
def _int64_feature(value):
    """Wrap an int (or iterable of ints) as a tf.train.Feature of int64s.

    A scalar is promoted to a one-element list, since Int64List requires
    a sequence.
    """
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
4041
42+
def _bytes_feature(value):
    """Wrap a single bytes value as a tf.train.Feature of bytes."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
4345
46+
def error(msg):
    """Print an error message to stderr and terminate with exit code 1.

    Args:
        msg: human-readable description of the failure.
    """
    import sys  # local import: the file's import block is outside this view
    # Errors belong on stderr; sys.exit is preferred over the site builtin
    # exit(), which is not guaranteed to exist outside interactive sessions.
    print('Error: ' + msg, file=sys.stderr)
    sys.exit(1)
4750
51+
def x_to_uint8(x):
    """Quantise a float image tensor to uint8: floor, clamp to [0, 255], cast."""
    return tf.cast(tf.clip_by_value(tf.floor(x), 0, 255), 'uint8')
5054
55+
def centre_crop(img):
    """Crop the largest centred square from an HWC image tensor.

    Args:
        img: rank-3 tensor laid out (height, width, channels) — assumed
            from the [0]/[1] shape indexing; confirm against callers.

    Returns:
        A (min_side, min_side, channels) tensor cut from the centre,
        where min_side = min(height, width).
    """
    h, w = tf.shape(img)[0], tf.shape(img)[1]
    min_side = tf.minimum(h, w)
    # Offsets centre the square along the longer axis (integer division
    # biases one pixel toward the top/left for odd differences).
    h_offset = (h - min_side) // 2
    w_offset = (w - min_side) // 2
    return tf.image.crop_to_bounding_box(img, h_offset, w_offset, min_side, min_side)
5762
63+
def downsample(img):
    """Box-filter 2x downsample of an HWC image: average each 2x2 block.

    Works on any array supporting extended slicing (TF tensors, numpy
    arrays). Height and width are assumed even; odd trailing rows/cols
    would be dropped by the strided slices — TODO confirm inputs are even.
    """
    return (img[0::2, 0::2, :] + img[0::2, 1::2, :]
            + img[1::2, 0::2, :] + img[1::2, 1::2, :]) * 0.25
6066
67+
6168def parse_image (max_res ):
6269 def _process_image (img ):
6370 img = centre_crop (img )
64- img = tf .image .resize_images (img , [max_res , max_res ], method = _DOWNSAMPLING )
71+ img = tf .image .resize_images (
72+ img , [max_res , max_res ], method = _DOWNSAMPLING )
6573 img = tf .cast (img , 'float32' )
6674 resolution_log2 = int (np .log2 (max_res ))
6775 q_imgs = []
@@ -89,7 +97,8 @@ def _parse_image(example):
8997
9098 return _parse_image
9199
92- def parse_celeba_image (max_res , transpose = False ):
100+
101+ def parse_celeba_image (max_res , transpose = False ):
93102 def _process_image (img ):
94103 img = tf .cast (img , 'float32' )
95104 resolution_log2 = int (np .log2 (max_res ))
@@ -112,26 +121,29 @@ def _parse_image(example):
112121 data = tf .decode_raw (data , tf .uint8 )
113122 img = tf .reshape (data , shape )
114123 if transpose :
115- img = tf .transpose (img , (1 ,2 , 0 )) # CHW -> HWC
124+ img = tf .transpose (img , (1 , 2 , 0 )) # CHW -> HWC
116125 imgs = _process_image (img )
117126 parsed = (attr , * imgs )
118127 return parsed
119128
120129 return _parse_image
121130
131+
def get_tfr_files(data_dir, split, lgres):
    """Return a glob pattern matching sharded tfrecord files for a split.

    Args:
        data_dir: dataset root directory.
        split: split subdirectory name, e.g. 'train' or 'validation'.
        lgres: log2 of the image resolution encoded in the filename.

    Returns:
        Pattern '<root>/<split>/<split>-rNN-s-*-of-*.tfrecords', matching
        every shard written for that resolution.
    """
    data_dir = os.path.join(data_dir, split)
    # The file prefix repeats the split directory's own name.
    tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
    return tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % lgres
127137
138+
def get_tfr_file(data_dir, split, lgres):
    """Return the path of a single (unsharded) tfrecord file.

    Args:
        data_dir: dataset root directory.
        split: split subdirectory name; if falsy (''), data_dir is used
            directly without appending a split component.
        lgres: log2 of the image resolution encoded in the filename.

    Returns:
        Path '<dir>/<basename(dir)>-rNN.tfrecords'.
    """
    if split:
        data_dir = os.path.join(data_dir, split)
    tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
    return tfr_prefix + '-r%02d.tfrecords' % lgres
134145
146+
135147def dump_celebahq (data_dir , tfrecord_dir , max_res , split , write ):
136148 _NUM_IMAGES = {
137149 'train' : 27000 ,
@@ -150,7 +162,8 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
150162 if split :
151163 tfr_files = get_tfr_files (data_dir , split , int (np .log2 (max_res )))
152164 files = tf .data .Dataset .list_files (tfr_files )
153- dset = files .apply (tf .contrib .data .parallel_interleave (tf .data .TFRecordDataset , cycle_length = _NUM_PARALLEL_FILE_READERS ))
165+ dset = files .apply (tf .contrib .data .parallel_interleave (
166+ tf .data .TFRecordDataset , cycle_length = _NUM_PARALLEL_FILE_READERS ))
154167 transpose = False
155168 else :
156169 tfr_file = get_tfr_file (data_dir , "" , int (np .log2 (max_res )))
@@ -173,10 +186,12 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
173186 if write :
174187 tfr .add_image (0 , imgs , attr )
175188 if write :
176- assert tfr .cur_images == total_imgs , (tfr .cur_images , total_imgs )
189+ assert tfr .cur_images == total_imgs , (
190+ tfr .cur_images , total_imgs )
177191
178192 #attr, *imgs = sess.run([_attr, *_imgs])
179193
194+
180195def dump_imagenet (data_dir , tfrecord_dir , max_res , split , write ):
181196 _NUM_IMAGES = {
182197 'train' : 1281167 ,
@@ -194,9 +209,11 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
194209 with tf .Session () as sess :
195210 is_training = (split == 'train' )
196211 if is_training :
197- files = tf .data .Dataset .list_files (os .path .join (data_dir , 'train-*-of-01024' ))
212+ files = tf .data .Dataset .list_files (
213+ os .path .join (data_dir , 'train-*-of-01024' ))
198214 else :
199- files = tf .data .Dataset .list_files (os .path .join (data_dir , 'validation-*-of-00128' ))
215+ files = tf .data .Dataset .list_files (
216+ os .path .join (data_dir , 'validation-*-of-00128' ))
200217
201218 files = files .shuffle (buffer_size = _NUM_FILES [split ])
202219
@@ -205,7 +222,8 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
205222
206223 dataset = dataset .shuffle (buffer_size = _SHUFFLE_BUFFER )
207224 parse_fn = parse_image (max_res )
208- dataset = dataset .map (parse_fn , num_parallel_calls = _NUM_PARALLEL_MAP_CALLS )
225+ dataset = dataset .map (
226+ parse_fn , num_parallel_calls = _NUM_PARALLEL_MAP_CALLS )
209227 dataset = dataset .prefetch (1 )
210228 iterator = dataset .make_one_shot_iterator ()
211229
@@ -225,10 +243,12 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
225243
226244 #label, *imgs = sess.run([_label, *_imgs])
227245
246+
228247class TFRecordExporter :
229248 def __init__ (self , tfrecord_dir , resolution_log2 , expected_images , shards , print_progress = True , progress_interval = 10 ):
230249 self .tfrecord_dir = tfrecord_dir
231- self .tfr_prefix = os .path .join (self .tfrecord_dir , os .path .basename (self .tfrecord_dir ))
250+ self .tfr_prefix = os .path .join (
251+ self .tfrecord_dir , os .path .basename (self .tfrecord_dir ))
232252 self .resolution_log2 = resolution_log2
233253 self .expected_images = expected_images
234254
@@ -242,19 +262,24 @@ def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print
242262 if not os .path .isdir (self .tfrecord_dir ):
243263 os .makedirs (self .tfrecord_dir )
244264 assert (os .path .isdir (self .tfrecord_dir ))
245- tfr_opt = tf .python_io .TFRecordOptions (tf .python_io .TFRecordCompressionType .NONE )
265+ tfr_opt = tf .python_io .TFRecordOptions (
266+ tf .python_io .TFRecordCompressionType .NONE )
246267 for lod in range (self .resolution_log2 - 1 ):
247- p_shard = np .array_split (np .random .permutation (expected_images ),shards )
248- img_to_shard = np .zeros (expected_images , dtype = np .int )
268+ p_shard = np .array_split (
269+ np .random .permutation (expected_images ), shards )
270+ img_to_shard = np .zeros (expected_images , dtype = np .int )
249271 writers = []
250272 for shard in range (shards ):
251273 img_to_shard [p_shard [shard ]] = shard
252- tfr_file = self .tfr_prefix + '-r%02d-s-%04d-of-%04d.tfrecords' % (self .resolution_log2 - lod , shard , shards )
274+ tfr_file = self .tfr_prefix + \
275+ '-r%02d-s-%04d-of-%04d.tfrecords' % (
276+ self .resolution_log2 - lod , shard , shards )
253277 writers .append (tf .python_io .TFRecordWriter (tfr_file , tfr_opt ))
254278 #print(np.unique(img_to_shard, return_counts=True))
255279 counts = np .unique (img_to_shard , return_counts = True )[1 ]
256280 assert len (counts ) == shards
257- print ("Smallest and largest shards have size" , np .min (counts ), np .max (counts ))
281+ print ("Smallest and largest shards have size" ,
282+ np .min (counts ), np .max (counts ))
258283 self .tfr_writers .append ((writers , img_to_shard ))
259284
260285 def close (self ):
@@ -286,7 +311,8 @@ def add_image(self, label, imgs, attr):
286311 }
287312 )
288313 )
289- writers [img_to_shard [self .cur_images ]].write (ex .SerializeToString ())
314+ writers [img_to_shard [self .cur_images ]].write (
315+ ex .SerializeToString ())
290316 self .cur_images += 1
291317
292318 # def add_labels(self, labels):
@@ -302,16 +328,20 @@ def __enter__(self):
302328 def __exit__ (self , * args ):
303329 self .close ()
304330
331+
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--max_res", type=int, default=256, help="Image size")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()  # So error if typo
    # Alternate dataset: uncomment to dump ImageNet instead of CelebA-HQ.
    #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
    #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)
    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
                  hps.max_res, 'validation', hps.write)
    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
                  hps.max_res, 'train', hps.write)