Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@ Code for reproducing results in ["Glow: Generative Flow with Invertible 1x1 Conv
## Requirements

- Tensorflow (tested with v1.8.0)
- Horovod (tested with v0.13.4) and (Open)MPI
- Horovod (tested with v0.13.8) and (Open)MPI

Run
```
pip install -r requirements.txt
```

To setup (Open)MPI, check instructions on Horovod github [page](https://github.com/uber/horovod).

## Download datasets
The datasets are in the Google Cloud locations `https://storage.googleapis.com/glow-demo/data/{dataset_name}-tfr.tar`. The dataset_names are below, we mention the exact preprocessing / downsampling method for a correct comparison of likelihood.
Expand All @@ -21,11 +28,18 @@ Qualitative results

To download and extract celeb for example, run
```
curl https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
tar -xvf celeba-tfr.tar
```
Change `hps.data_dir` in train.py file to point to the above folder (or use the `--data_dir` flag when you run train.py)

## Simple Train with 1 GPU

Run with a small depth to test
```
CUDA_VISIBLE_DEVICES=0 python train.py --depth 1
```

## Train with multiple GPUs using MPI and Horovod

Run default training script with 8 GPUs:
Expand Down Expand Up @@ -80,4 +94,4 @@ mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --dep
##### Conditional ImageNet 32x32 Qualitative result
```
mpiexec -n 8 python train.py --problem imagenet --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
```
```
1 change: 0 additions & 1 deletion data_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

74 changes: 52 additions & 22 deletions data_loaders/generate_tfr/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,43 @@
_DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR
_SHUFFLE_BUFFER = 1024


def _int64_feature(value):
    """Wrap an int (or iterable of ints) in a tf.train.Feature holding an Int64List."""
    values = value if isinstance(value, Iterable) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a one-element BytesList."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def error(msg):
    """Print an error message to stderr and terminate with exit status 1.

    Uses sys.exit instead of the site-module builtin exit(), which is only
    guaranteed to exist in interactive sessions, and writes to stderr so the
    message is not mixed into piped stdout output.
    """
    import sys  # local import keeps this fix self-contained in the block
    print('Error: ' + msg, file=sys.stderr)
    sys.exit(1)


def x_to_uint8(x):
    """Quantize a float tensor to uint8: floor, clamp to [0, 255], then cast."""
    return tf.cast(tf.clip_by_value(tf.floor(x), 0, 255), 'uint8')


def centre_crop(img):
    """Crop the largest centred square from an HWC image tensor.

    Uses dynamic (graph-time) shapes via tf.shape, so it works for images
    whose dimensions are unknown until runtime.
    """
    h, w = tf.shape(img)[0], tf.shape(img)[1]
    min_side = tf.minimum(h, w)
    # Centre the square along whichever dimension is larger.
    h_offset = (h - min_side) // 2
    w_offset = (w - min_side) // 2
    return tf.image.crop_to_bounding_box(img, h_offset, w_offset, min_side, min_side)


def downsample(img):
    """Halve the spatial resolution of an HWC image by 2x2 average pooling.

    Each output pixel is the mean of the corresponding 2x2 input patch.
    Assumes even height and width.
    """
    # Accumulate the four corners of every 2x2 patch in the same order as
    # a left-to-right sum, then divide by four.
    acc = img[0::2, 0::2, :]
    acc = acc + img[0::2, 1::2, :]
    acc = acc + img[1::2, 0::2, :]
    acc = acc + img[1::2, 1::2, :]
    return acc * 0.25


def parse_image(max_res):
def _process_image(img):
img = centre_crop(img)
img = tf.image.resize_images(img, [max_res, max_res], method=_DOWNSAMPLING)
img = tf.image.resize_images(
img, [max_res, max_res], method=_DOWNSAMPLING)
img = tf.cast(img, 'float32')
resolution_log2 = int(np.log2(max_res))
q_imgs = []
Expand Down Expand Up @@ -89,7 +97,8 @@ def _parse_image(example):

return _parse_image

def parse_celeba_image(max_res, transpose = False):

def parse_celeba_image(max_res, transpose=False):
def _process_image(img):
img = tf.cast(img, 'float32')
resolution_log2 = int(np.log2(max_res))
Expand All @@ -112,26 +121,29 @@ def _parse_image(example):
data = tf.decode_raw(data, tf.uint8)
img = tf.reshape(data, shape)
if transpose:
img = tf.transpose(img, (1,2,0)) # CHW -> HWC
img = tf.transpose(img, (1, 2, 0)) # CHW -> HWC
imgs = _process_image(img)
parsed = (attr, *imgs)
return parsed

return _parse_image


def get_tfr_files(data_dir, split, lgres):
    """Return the glob pattern matching all sharded tfrecord files for a split.

    lgres is the log2 of the image resolution encoded in the filename.
    """
    split_dir = os.path.join(data_dir, split)
    prefix = os.path.join(split_dir, os.path.basename(split_dir))
    return f'{prefix}-r{lgres:02d}-s-*-of-*.tfrecords'


def get_tfr_file(data_dir, split, lgres):
    """Return the path of the single (unsharded) tfrecord file for a split.

    An empty/falsy split means data_dir itself already is the split directory.
    """
    base = os.path.join(data_dir, split) if split else data_dir
    return os.path.join(base, os.path.basename(base)) + f'-r{lgres:02d}.tfrecords'


def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
_NUM_IMAGES = {
'train': 27000,
Expand All @@ -150,7 +162,8 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
if split:
tfr_files = get_tfr_files(data_dir, split, int(np.log2(max_res)))
files = tf.data.Dataset.list_files(tfr_files)
dset = files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
dset = files.apply(tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
transpose = False
else:
tfr_file = get_tfr_file(data_dir, "", int(np.log2(max_res)))
Expand All @@ -173,10 +186,12 @@ def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
if write:
tfr.add_image(0, imgs, attr)
if write:
assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs)
assert tfr.cur_images == total_imgs, (
tfr.cur_images, total_imgs)

#attr, *imgs = sess.run([_attr, *_imgs])


def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
_NUM_IMAGES = {
'train': 1281167,
Expand All @@ -194,9 +209,11 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
with tf.Session() as sess:
is_training = (split == 'train')
if is_training:
files = tf.data.Dataset.list_files(os.path.join(data_dir, 'train-*-of-01024'))
files = tf.data.Dataset.list_files(
os.path.join(data_dir, 'train-*-of-01024'))
else:
files = tf.data.Dataset.list_files(os.path.join(data_dir, 'validation-*-of-00128'))
files = tf.data.Dataset.list_files(
os.path.join(data_dir, 'validation-*-of-00128'))

files = files.shuffle(buffer_size=_NUM_FILES[split])

Expand All @@ -205,7 +222,8 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):

dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
parse_fn = parse_image(max_res)
dataset = dataset.map(parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
dataset = dataset.map(
parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
dataset = dataset.prefetch(1)
iterator = dataset.make_one_shot_iterator()

Expand All @@ -225,10 +243,12 @@ def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):

#label, *imgs = sess.run([_label, *_imgs])


class TFRecordExporter:
def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10):
self.tfrecord_dir = tfrecord_dir
self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
self.tfr_prefix = os.path.join(
self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
self.resolution_log2 = resolution_log2
self.expected_images = expected_images

Expand All @@ -242,19 +262,24 @@ def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print
if not os.path.isdir(self.tfrecord_dir):
os.makedirs(self.tfrecord_dir)
assert (os.path.isdir(self.tfrecord_dir))
tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
tfr_opt = tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.NONE)
for lod in range(self.resolution_log2 - 1):
p_shard = np.array_split(np.random.permutation(expected_images),shards)
img_to_shard = np.zeros(expected_images, dtype = np.int)
p_shard = np.array_split(
np.random.permutation(expected_images), shards)
img_to_shard = np.zeros(expected_images, dtype=np.int)
writers = []
for shard in range(shards):
img_to_shard[p_shard[shard]] = shard
tfr_file = self.tfr_prefix + '-r%02d-s-%04d-of-%04d.tfrecords' % (self.resolution_log2 - lod, shard, shards)
tfr_file = self.tfr_prefix + \
'-r%02d-s-%04d-of-%04d.tfrecords' % (
self.resolution_log2 - lod, shard, shards)
writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
#print(np.unique(img_to_shard, return_counts=True))
counts = np.unique(img_to_shard, return_counts=True)[1]
assert len(counts) == shards
print("Smallest and largest shards have size", np.min(counts), np.max(counts))
print("Smallest and largest shards have size",
np.min(counts), np.max(counts))
self.tfr_writers.append((writers, img_to_shard))

def close(self):
Expand Down Expand Up @@ -286,7 +311,8 @@ def add_image(self, label, imgs, attr):
}
)
)
writers[img_to_shard[self.cur_images]].write(ex.SerializeToString())
writers[img_to_shard[self.cur_images]].write(
ex.SerializeToString())
self.cur_images += 1

# def add_labels(self, labels):
Expand All @@ -302,16 +328,20 @@ def __enter__(self):
def __exit__(self, *args):
self.close()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, required = True)
parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument("--max_res", type=int, default=256, help="Image size")
parser.add_argument("--tfrecord_dir", type=str, required = True, help = 'place to dump')
parser.add_argument("--write", action = 'store_true', help = "Whether to write")
hps = parser.parse_args() # So error if typo
parser.add_argument("--tfrecord_dir", type=str,
required=True, help='place to dump')
parser.add_argument("--write", action='store_true',
help="Whether to write")
hps = parser.parse_args() # So error if typo
#dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
#dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)
dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)

dump_celebahq(hps.data_dir, hps.tfrecord_dir,
hps.max_res, 'validation', hps.write)
dump_celebahq(hps.data_dir, hps.tfrecord_dir,
hps.max_res, 'train', hps.write)
23 changes: 16 additions & 7 deletions data_loaders/generate_tfr/imagenet_oord.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@

from typing import Iterable


def _int64_feature(value):
    """Wrap an int (or iterable of ints) in a tf.train.Feature holding an Int64List."""
    # Promote a bare scalar to a one-element list for Int64List.
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a one-element BytesList."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
"""Main converter function."""
# fn_root = FLAGS.fn_root
Expand All @@ -57,19 +60,23 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
assert num_examples == expected_images

# Sharding
tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
tfr_opt = tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.NONE)
p_shard = np.array_split(np.random.permutation(expected_images), shards)
img_to_shard = np.zeros(expected_images, dtype=np.int)
writers = []
for shard in range(shards):
img_to_shard[p_shard[shard]] = shard
tfr_file = tfr_prefix + '-r%02d-s-%04d-of-%04d.tfrecords' % (resolution_log2, shard, shards)
tfr_file = tfr_prefix + \
'-r%02d-s-%04d-of-%04d.tfrecords' % (
resolution_log2, shard, shards)
writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))

# print(np.unique(img_to_shard, return_counts=True))
counts = np.unique(img_to_shard, return_counts=True)[1]
assert len(counts) == shards
print("Smallest and largest shards have size", np.min(counts), np.max(counts))
print("Smallest and largest shards have size",
np.min(counts), np.max(counts))

for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
shard = img_to_shard[example_idx]
Expand Down Expand Up @@ -105,8 +112,10 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):

parser = argparse.ArgumentParser()
parser.add_argument("--res", type=int, default=32, help="Image size")
parser.add_argument("--tfrecord_dir", type=str, required=True, help='place to dump')
parser.add_argument("--write", action='store_true', help="Whether to write")
parser.add_argument("--tfrecord_dir", type=str,
required=True, help='place to dump')
parser.add_argument("--write", action='store_true',
help="Whether to write")
hps = parser.parse_args()

# Imagenet
Expand All @@ -127,9 +136,9 @@ def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):

for split in ['validation', 'train']:
fn_root = _FILE[split]
tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
total_imgs = _NUM_IMAGES[split]
shards = _NUM_SHARDS[split]
if not os.path.exists(tfrecord_dir):
os.mkdir(tfrecord_dir)
dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
Loading