
Commit a382fc3

Merge pull request #1 from openai/master
Merge the original upstream work
2 parents c5ba53b + 1f13529 commit a382fc3

15 files changed

Lines changed: 713 additions & 358 deletions

README.md

Lines changed: 17 additions & 3 deletions
@@ -5,7 +5,14 @@ Code for reproducing results in ["Glow: Generative Flow with Invertible 1x1 Conv
 ## Requirements
 
  - Tensorflow (tested with v1.8.0)
- - Horovod (tested with v0.13.4) and (Open)MPI
+ - Horovod (tested with v0.13.8) and (Open)MPI
+
+Run
+```
+pip install -r requirements.txt
+```
+
+To setup (Open)MPI, check instructions on Horovod github [page](https://github.com/uber/horovod).
 
 ## Download datasets
 The datasets are in the Google Cloud locations `https://storage.googleapis.com/glow-demo/data/{dataset_name}-tfr.tar`. The dataset_names are below, we mention the exact preprocessing / downsampling method for a correct comparison of likelihood.
@@ -21,11 +28,18 @@ Qualitative results
 
 To download and extract celeb for example, run
 ```
-curl https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
+wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
 tar -xvf celeb-tfr.tar
 ```
 Change `hps.data_dir` in train.py file to point to the above folder (or use the `--data_dir` flag when you run train.py)
 
+## Simple Train with 1 GPU
+
+Run with small depth to test
+```
+CUDA_VISIBLE_DEVICES=0 python train.py --depth 1
+```
+
 ## Train with multiple GPUs using MPI and Horovod
 
 Run default training script with 8 GPUs:
@@ -80,4 +94,4 @@ mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --dep
 ##### Conditional ImageNet 32x32 Qualitative result
 ```
 mpiexec -n 8 python train.py --problem imagenet --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
-```
+```
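
The quick-start commands added above fetch `celeba-tfr` and launch a small test run. Before training, it can help to confirm the extracted TFRecords are complete; below is a minimal editor's sketch (not part of this commit), assuming the tar was unpacked into `./celeba-tfr` and that the full-resolution 256x256 shards use the `-r08-` suffix produced by the `get_tfr_files` naming scheme in generate.py further down:

```
# Editor's sketch, not part of the commit: count records in the downloaded
# celeba-tfr training shards before launching train.py. The path and the
# r08 (256x256) resolution are assumptions about how the tar was extracted.
import glob

import tensorflow as tf  # TF 1.x, matching the README requirements

shard_files = sorted(glob.glob('celeba-tfr/train/train-r08-s-*-of-*.tfrecords'))
num_records = sum(1 for fn in shard_files
                  for _ in tf.python_io.tf_record_iterator(fn))
print('%d shards, %d records' % (len(shard_files), num_records))
```

If the download is complete, the record count for the training split should match the 27000 images hard-coded in generate.py's `_NUM_IMAGES`.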

data_loaders/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1 +0,0 @@
-

data_loaders/generate_tfr/generate.py

Lines changed: 52 additions & 22 deletions
@@ -33,35 +33,43 @@
 _DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR
 _SHUFFLE_BUFFER = 1024
 
+
 def _int64_feature(value):
     if not isinstance(value, Iterable):
         value = [value]
     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
 
+
 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
+
 def error(msg):
     print('Error: ' + msg)
     exit(1)
 
+
 def x_to_uint8(x):
     return tf.cast(tf.clip_by_value(tf.floor(x), 0, 255), 'uint8')
 
+
 def centre_crop(img):
     h, w = tf.shape(img)[0], tf.shape(img)[1]
     min_side = tf.minimum(h, w)
     h_offset = (h - min_side) // 2
     w_offset = (w - min_side) // 2
     return tf.image.crop_to_bounding_box(img, h_offset, w_offset, min_side, min_side)
 
+
 def downsample(img):
     return (img[0::2, 0::2, :] + img[0::2, 1::2, :] + img[1::2, 0::2, :] + img[1::2, 1::2, :]) * 0.25
 
+
 def parse_image(max_res):
     def _process_image(img):
         img = centre_crop(img)
-        img = tf.image.resize_images(img, [max_res, max_res], method=_DOWNSAMPLING)
+        img = tf.image.resize_images(
+            img, [max_res, max_res], method=_DOWNSAMPLING)
         img = tf.cast(img, 'float32')
         resolution_log2 = int(np.log2(max_res))
         q_imgs = []
@@ -89,7 +97,8 @@ def _parse_image(example):
 
     return _parse_image
 
-def parse_celeba_image(max_res, transpose = False):
+
+def parse_celeba_image(max_res, transpose=False):
     def _process_image(img):
         img = tf.cast(img, 'float32')
         resolution_log2 = int(np.log2(max_res))
@@ -112,26 +121,29 @@ def _parse_image(example):
         data = tf.decode_raw(data, tf.uint8)
         img = tf.reshape(data, shape)
         if transpose:
-            img = tf.transpose(img, (1,2,0)) # CHW -> HWC
+            img = tf.transpose(img, (1, 2, 0))  # CHW -> HWC
         imgs = _process_image(img)
         parsed = (attr, *imgs)
         return parsed
 
     return _parse_image
 
+
 def get_tfr_files(data_dir, split, lgres):
     data_dir = os.path.join(data_dir, split)
     tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
     tfr_files = tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % (lgres)
     return tfr_files
 
+
 def get_tfr_file(data_dir, split, lgres):
     if split:
         data_dir = os.path.join(data_dir, split)
     tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
     tfr_file = tfr_prefix + '-r%02d.tfrecords' % (lgres)
     return tfr_file
 
+
 def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
     _NUM_IMAGES = {
         'train': 27000,
@@ -150,7 +162,8 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
         if split:
             tfr_files = get_tfr_files(data_dir, split, int(np.log2(max_res)))
             files = tf.data.Dataset.list_files(tfr_files)
-            dset = files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
+            dset = files.apply(tf.contrib.data.parallel_interleave(
+                tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
             transpose = False
         else:
             tfr_file = get_tfr_file(data_dir, "", int(np.log2(max_res)))
@@ -173,10 +186,12 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
             if write:
                 tfr.add_image(0, imgs, attr)
         if write:
-            assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs)
+            assert tfr.cur_images == total_imgs, (
+                tfr.cur_images, total_imgs)
 
     #attr, *imgs = sess.run([_attr, *_imgs])
 
+
 def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
     _NUM_IMAGES = {
         'train': 1281167,
@@ -194,9 +209,11 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
     with tf.Session() as sess:
         is_training = (split == 'train')
         if is_training:
-            files = tf.data.Dataset.list_files(os.path.join(data_dir, 'train-*-of-01024'))
+            files = tf.data.Dataset.list_files(
+                os.path.join(data_dir, 'train-*-of-01024'))
         else:
-            files = tf.data.Dataset.list_files(os.path.join(data_dir, 'validation-*-of-00128'))
+            files = tf.data.Dataset.list_files(
+                os.path.join(data_dir, 'validation-*-of-00128'))
 
         files = files.shuffle(buffer_size=_NUM_FILES[split])
 
@@ -205,7 +222,8 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
 
         dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
         parse_fn = parse_image(max_res)
-        dataset = dataset.map(parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
+        dataset = dataset.map(
+            parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
         dataset = dataset.prefetch(1)
         iterator = dataset.make_one_shot_iterator()
 
@@ -225,10 +243,12 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
 
     #label, *imgs = sess.run([_label, *_imgs])
 
+
 class TFRecordExporter:
     def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10):
         self.tfrecord_dir = tfrecord_dir
-        self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
+        self.tfr_prefix = os.path.join(
+            self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
         self.resolution_log2 = resolution_log2
         self.expected_images = expected_images
 
@@ -242,19 +262,24 @@ def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print
         if not os.path.isdir(self.tfrecord_dir):
             os.makedirs(self.tfrecord_dir)
         assert (os.path.isdir(self.tfrecord_dir))
-        tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
+        tfr_opt = tf.python_io.TFRecordOptions(
+            tf.python_io.TFRecordCompressionType.NONE)
         for lod in range(self.resolution_log2 - 1):
-            p_shard = np.array_split(np.random.permutation(expected_images),shards)
-            img_to_shard = np.zeros(expected_images, dtype = np.int)
+            p_shard = np.array_split(
+                np.random.permutation(expected_images), shards)
+            img_to_shard = np.zeros(expected_images, dtype=np.int)
             writers = []
             for shard in range(shards):
                 img_to_shard[p_shard[shard]] = shard
-                tfr_file = self.tfr_prefix + '-r%02d-s-%04d-of-%04d.tfrecords' % (self.resolution_log2 - lod, shard, shards)
+                tfr_file = self.tfr_prefix + \
+                    '-r%02d-s-%04d-of-%04d.tfrecords' % (
+                        self.resolution_log2 - lod, shard, shards)
                 writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
             #print(np.unique(img_to_shard, return_counts=True))
             counts = np.unique(img_to_shard, return_counts=True)[1]
             assert len(counts) == shards
-            print("Smallest and largest shards have size", np.min(counts), np.max(counts))
+            print("Smallest and largest shards have size",
+                  np.min(counts), np.max(counts))
             self.tfr_writers.append((writers, img_to_shard))
 
     def close(self):
@@ -286,7 +311,8 @@ def add_image(self, label, imgs, attr):
                 }
             )
         )
-        writers[img_to_shard[self.cur_images]].write(ex.SerializeToString())
+        writers[img_to_shard[self.cur_images]].write(
+            ex.SerializeToString())
        self.cur_images += 1
 
     # def add_labels(self, labels):
@@ -302,16 +328,20 @@ def __enter__(self):
     def __exit__(self, *args):
         self.close()
 
+
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
-    parser.add_argument("--data_dir", type=str, required = True)
+    parser.add_argument("--data_dir", type=str, required=True)
     parser.add_argument("--max_res", type=int, default=256, help="Image size")
-    parser.add_argument("--tfrecord_dir", type=str, required = True, help = 'place to dump')
-    parser.add_argument("--write", action = 'store_true', help = "Whether to write")
-    hps = parser.parse_args() # So error if typo
+    parser.add_argument("--tfrecord_dir", type=str,
+                        required=True, help='place to dump')
+    parser.add_argument("--write", action='store_true',
+                        help="Whether to write")
+    hps = parser.parse_args()  # So error if typo
     #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
     #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)
-    dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
-    dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)
-
+    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
+                  hps.max_res, 'validation', hps.write)
+    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
+                  hps.max_res, 'train', hps.write)
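
Most of the generate.py changes above are mechanical line wrapping, but they touch the shard-assignment logic in `TFRecordExporter.__init__`, which is easy to misread once it is split across continuation lines. Here is a standalone sketch of that scheme with toy sizes (10 images, 3 shards); the variable names follow the code above, the sizes are made up:

```
# Illustration of the sharding scheme in TFRecordExporter.__init__ above:
# a random permutation of image indices is split into `shards` nearly equal
# groups, and img_to_shard maps each image index to the writer that stores it.
import numpy as np

expected_images, shards = 10, 3
p_shard = np.array_split(np.random.permutation(expected_images), shards)
img_to_shard = np.zeros(expected_images, dtype=int)  # plain int; np.int is deprecated
for shard in range(shards):
    img_to_shard[p_shard[shard]] = shard

counts = np.unique(img_to_shard, return_counts=True)[1]
print(img_to_shard)   # e.g. [2 0 1 0 1 2 0 0 1 2]
print(counts)         # shard sizes differ by at most one, e.g. [4 3 3]
```

Because `np.array_split` tolerates a shard count that does not divide the image count evenly, the smallest and largest shard sizes printed by the script can differ by one.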

data_loaders/generate_tfr/imagenet_oord.py

Lines changed: 16 additions & 7 deletions
@@ -33,14 +33,17 @@
 
 from typing import Iterable
 
+
 def _int64_feature(value):
     if not isinstance(value, Iterable):
         value = [value]
     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
 
+
 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
+
 def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
     """Main converter function."""
     # fn_root = FLAGS.fn_root
@@ -57,19 +60,23 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
     assert num_examples == expected_images
 
     # Sharding
-    tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
+    tfr_opt = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.NONE)
     p_shard = np.array_split(np.random.permutation(expected_images), shards)
     img_to_shard = np.zeros(expected_images, dtype=np.int)
     writers = []
     for shard in range(shards):
         img_to_shard[p_shard[shard]] = shard
-        tfr_file = tfr_prefix + '-r%02d-s-%04d-of-%04d.tfrecords' % (resolution_log2, shard, shards)
+        tfr_file = tfr_prefix + \
+            '-r%02d-s-%04d-of-%04d.tfrecords' % (
+                resolution_log2, shard, shards)
         writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
 
     # print(np.unique(img_to_shard, return_counts=True))
     counts = np.unique(img_to_shard, return_counts=True)[1]
     assert len(counts) == shards
-    print("Smallest and largest shards have size", np.min(counts), np.max(counts))
+    print("Smallest and largest shards have size",
+          np.min(counts), np.max(counts))
 
     for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
         shard = img_to_shard[example_idx]
@@ -105,8 +112,10 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--res", type=int, default=32, help="Image size")
-parser.add_argument("--tfrecord_dir", type=str, required=True, help='place to dump')
-parser.add_argument("--write", action='store_true', help="Whether to write")
+parser.add_argument("--tfrecord_dir", type=str,
+                    required=True, help='place to dump')
+parser.add_argument("--write", action='store_true',
+                    help="Whether to write")
 hps = parser.parse_args()
 
 # Imagenet
@@ -127,9 +136,9 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
 
 for split in ['validation', 'train']:
     fn_root = _FILE[split]
-    tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
+    tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
     total_imgs = _NUM_IMAGES[split]
     shards = _NUM_SHARDS[split]
     if not os.path.exists(tfrecord_dir):
         os.mkdir(tfrecord_dir)
-    dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
+    dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
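
Both converter scripts serialize images through the same two helpers, `_int64_feature` and `_bytes_feature`. Here is a minimal sketch of how they compose into the serialized `tf.train.Example` that each `TFRecordWriter.write()` call receives; the feature keys and the dummy 32x32 image are illustrative assumptions, not the exact record layout used by these scripts:

```
# Editor's sketch: building one serialized Example with the helpers defined
# above. The keys 'shape', 'label' and 'data' are assumptions for illustration.
import numpy as np
import tensorflow as tf

img = np.zeros((32, 32, 3), dtype=np.uint8)   # stand-in for a decoded image

ex = tf.train.Example(features=tf.train.Features(feature={
    'shape': _int64_feature(img.shape),       # an iterable -> Int64List([32, 32, 3])
    'label': _int64_feature(0),               # a scalar is wrapped into a list
    'data': _bytes_feature(img.tobytes()),    # raw uint8 bytes
}))
serialized = ex.SerializeToString()           # what TFRecordWriter.write() receives
```

Reading a record back is the job of the `_parse_image` closures in generate.py, which decode the raw bytes with `tf.decode_raw` as shown in the diff above.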
