-
Notifications
You must be signed in to change notification settings - Fork 190
Expand file tree
/
Copy pathmnist.py
More file actions
91 lines (77 loc) · 3.76 KB
/
mnist.py
File metadata and controls
91 lines (77 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Mnist Data loader, as given in Mnist tutorial
"""
import imageio
import torch
import torchvision.utils as v_utils
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, Dataset
class MnistDataLoader:
    """Build MNIST train/test ``DataLoader`` objects selected by ``config.data_mode``.

    Supported modes:

    * ``"download"`` -- torchvision's MNIST dataset rooted at ``../data``
      (downloaded if missing), converted with ``ToTensor`` and normalized.
    * ``"random"`` -- a single batch of standard-normal noise whose labels are
      all 1; the same tensors back both the train and the test loader.
    * ``"imgs"`` / ``"numpy"`` -- reserved; raise ``NotImplementedError``.

    After construction every supported mode exposes ``train_loader``,
    ``test_loader``, ``len_train_data``, ``len_valid_data``,
    ``train_iterations`` and ``valid_iterations``.
    """

    # Mean / std passed to ``transforms.Normalize`` for the single MNIST channel.
    _MNIST_MEAN = (0.1307,)
    _MNIST_STD = (0.3081,)

    def __init__(self, config):
        """
        :param config: attribute-style configuration object. ``data_mode`` is
            always read; depending on the mode, ``batch_size``,
            ``test_batch_size``, ``data_loader_workers``, ``pin_memory``,
            ``input_channels`` and ``img_size`` are read as well.
        :raises NotImplementedError: for the reserved "imgs" / "numpy" modes.
        :raises ValueError: for any other unrecognized ``data_mode``.
        """
        self.config = config

        if config.data_mode == "download":
            self._build_mnist_loaders()
        elif config.data_mode in ("imgs", "numpy"):
            raise NotImplementedError("This mode is not implemented YET")
        elif config.data_mode == "random":
            self._build_random_loaders()
        else:
            # ValueError subclasses Exception, so callers catching Exception
            # still work, while the error type is now meaningful.
            raise ValueError("Please specify in the json a specified mode in data_mode")

        # Same size/iteration attributes for every mode (previously only the
        # "random" mode provided them).
        self.len_train_data = len(self.train_loader.dataset)
        self.len_valid_data = len(self.test_loader.dataset)
        # With drop_last=False, len(DataLoader) == ceil(len(dataset) / batch_size).
        self.train_iterations = len(self.train_loader)
        self.valid_iterations = len(self.test_loader)

    def _build_mnist_loaders(self):
        """Create the torchvision MNIST train/test loaders ("download" mode)."""
        cfg = self.config
        # One preprocessing pipeline shared by both splits.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self._MNIST_MEAN, self._MNIST_STD),
        ])
        loader_kwargs = dict(num_workers=cfg.data_loader_workers,
                             pin_memory=cfg.pin_memory)

        self.train_loader = DataLoader(
            datasets.MNIST('../data', train=True, download=True, transform=transform),
            batch_size=cfg.batch_size, shuffle=True, **loader_kwargs)
        # download=True does not re-download existing files; it only makes the
        # test split usable on its own. Evaluation needs no shuffling: a fixed
        # order keeps test passes reproducible and matches the "random" mode.
        self.test_loader = DataLoader(
            datasets.MNIST('../data', train=False, download=True, transform=transform),
            batch_size=cfg.test_batch_size, shuffle=False, **loader_kwargs)

    def _build_random_loaders(self):
        """Create loaders over one batch of random noise ("random" mode)."""
        cfg = self.config
        data = torch.randn(cfg.batch_size, cfg.input_channels, cfg.img_size, cfg.img_size)
        labels = torch.ones(cfg.batch_size).long()
        # The test split reuses exactly the same tensors as the train split.
        self.train_loader = DataLoader(TensorDataset(data, labels),
                                       batch_size=cfg.batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(data, labels),
                                      batch_size=cfg.batch_size, shuffle=False)

    def _sample_path(self, epoch):
        """Return the PNG path of the sample grid saved for ``epoch``.

        ``config.out_dir`` is used as a plain string prefix, so it is expected
        to end with a path separator.
        """
        return '{}samples_epoch_{:d}.png'.format(self.config.out_dir, epoch)

    def plot_samples_per_epoch(self, batch, epoch):
        """
        Save ``batch`` as an image grid for ``epoch`` and read it back.

        :param batch: Tensor of shape (B,C,H,W)
        :param epoch: the number of current epoch
        :return: the saved grid image, as returned by ``imageio.imread``
        """
        img_epoch = self._sample_path(epoch)
        v_utils.save_image(batch, img_epoch, nrow=4, padding=2, normalize=True)
        return imageio.imread(img_epoch)

    def make_gif(self, epochs):
        """
        Combine the saved per-epoch sample grids into an animated GIF.

        Reads the grids for epochs ``0..epochs`` (inclusive) and writes
        ``animation_epochs_<epochs>.gif`` into ``config.out_dir`` at 2 fps.

        :param epochs: index of the last epoch to include
        """
        frames = []
        for epoch in range(epochs + 1):
            try:
                frames.append(imageio.imread(self._sample_path(epoch)))
            except OSError:
                # Best effort: epochs whose grid cannot be read (e.g. was never
                # saved) are skipped rather than aborting the whole GIF.
                continue
        imageio.mimsave(self.config.out_dir + 'animation_epochs_{:d}.gif'.format(epochs),
                        frames, fps=2)

    def finalize(self):
        """Nothing to clean up for this loader."""
        pass