-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata_utils.py
More file actions
49 lines (40 loc) · 1.42 KB
/
data_utils.py
File metadata and controls
49 lines (40 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from mxnet import nd
import numpy as np
import mxnet as mx
class Dataset(object):
    """Iterate over a list of image files in fixed-size, preprocessed batches.

    Each call to ``next()`` takes the next ``batch_size`` entries of
    ``img_list``, loads and preprocesses every image with ``transform`` and
    returns them concatenated along a new leading axis.  Only full batches
    are produced: a trailing remainder shorter than ``batch_size`` is
    dropped, so one pass yields exactly ``num_batches`` batches.

    Args:
        img_list: Sequence of image file names accepted by
            ``mx.image.imread``.
        img_dims: Side length (in pixels) every image is resized to; the
            same value is used for width and height.
        batch_size: Number of images per batch.
    """

    def __init__(self, img_list, img_dims=64, batch_size=64):
        self.img_list = img_list
        self.img_dims = img_dims
        self.batch_size = batch_size
        # Number of full batches available in one pass over img_list.
        self.num_batches = len(img_list) // batch_size
        # Index into img_list of the first image of the next batch.
        self.cur = 0

    def __iter__(self):
        """Return the dataset itself as the iterator.

        The cursor is not rewound here; call ``reset()`` before starting a
        new pass over the data.
        """
        return self

    def __next__(self):
        """Python 3 iterator protocol hook; delegates to ``next()``."""
        # The result must be returned, otherwise a ``for`` loop over the
        # dataset would receive ``None`` instead of the batch.
        return self.next()

    def next(self):
        """Return the next preprocessed batch and advance the cursor.

        Raises:
            StopIteration: If fewer than ``batch_size`` images remain.
        """
        if not self.has_next():
            raise StopIteration()

        end = self.cur + self.batch_size
        batch = self.process_batch(self.img_list[self.cur:end])
        self.cur = end
        return batch

    def reset(self):
        """Rewind the cursor so iteration starts again from the first image."""
        self.cur = 0

    def has_next(self):
        """Return True if at least one more full batch is available."""
        # ``<=`` (not ``<``): a batch ending exactly at the last image is
        # still a complete batch and must not be skipped.
        return self.cur + self.batch_size <= len(self.img_list)

    @staticmethod
    def transform(img, dims):
        """Load one image and convert it to a normalized float32 array.

        Args:
            img: Image file name passed to ``mx.image.imread``.
            dims: Target width and height passed to ``mx.image.imresize``.

        Returns:
            The preprocessed image with the channel axis moved to the front
            and an extra leading axis of size 1, so that several results can
            be concatenated into a batch.
        """
        data = mx.image.imread(img)
        data = mx.image.imresize(data, dims, dims)
        # Move the last (channel) axis to the front.
        data = nd.transpose(data, (2, 0, 1))
        # normalize to [-1, 1]
        # NOTE(review): assumes pixel values are in [0, 255] - confirm.
        data = data.astype(np.float32) / 127.5 - 1
        # if image is greyscale, repeat 3 times to get RGB image.
        if data.shape[0] == 1:
            data = nd.tile(data, (3, 1, 1))
        # Prepend a batch axis of size 1.
        return data.reshape((1,) + data.shape)

    def process_batch(self, batch):
        """Preprocess every image in ``batch`` and stack them into one array.

        Args:
            batch: Sequence of image file names.

        Returns:
            The ``nd.concatenate`` of the per-image arrays produced by
            ``transform`` (each of which carries a leading axis of size 1).
        """
        imgs = [self.transform(path, self.img_dims) for path in batch]
        return nd.concatenate(imgs)