Skip to content

Commit ab5c41d

Browse files
committed
add GAN with selftaughtlearning for unseen task processing
1 parent e12ccac commit ab5c41d

31 files changed

+2940
-1
lines changed

README_ospp.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# README
2+
------
3+
I modified sedna source code to get my algorithm integrated into lifelong learning. I would display the framework and illustrute what I modified.
4+
## Framework
5+
![](framework_with_gan_selftaughtlearning.png)
6+
7+
## What I Modified
8+
9+
1. Delete redundant annotations.
10+
2. Replce `print` with `logger`.
11+
3. Integrate my algorithm into `sedna.lib.sedna.algorithms.unseen_task_processing.unseen_task_processing.py`.
12+
4. Replace absolute path with relative path.
13+
5. Remove redundant code.
14+
6. Provide link for developers to download trained model.
15+
16+
## What I Refer to
17+
I refer to [FastGAN-pytorch](https://github.com/odegeasslbc/FastGAN-pytorch) to implement my GAN module.
18+
19+
## Model to Download
20+
In `GAN.lpips.weights`, developers may need the pre-trained model.
21+
Also, the trained GAN model and trained encoder model is also provided.
22+
Click here [link](https://drive.google.com/drive/folders/1IOQCQ3sntxrbt7RtJIsSlBo0PFrR7Ets?usp=share_link) to download.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import train
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Differentiable Augmentation for Data-Efficient GAN Training
2+
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3+
# https://arxiv.org/pdf/2006.10738
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
9+
def DiffAugment(x, policy='', channels_first=True):
10+
if policy:
11+
if not channels_first:
12+
x = x.permute(0, 3, 1, 2)
13+
for p in policy.split(','):
14+
for f in AUGMENT_FNS[p]:
15+
x = f(x)
16+
if not channels_first:
17+
x = x.permute(0, 2, 3, 1)
18+
x = x.contiguous()
19+
return x
20+
21+
22+
def rand_brightness(x):
23+
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
24+
return x
25+
26+
27+
def rand_saturation(x):
28+
x_mean = x.mean(dim=1, keepdim=True)
29+
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30+
return x
31+
32+
33+
def rand_contrast(x):
34+
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35+
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36+
return x
37+
38+
39+
def rand_translation(x, ratio=0.125):
40+
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
41+
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
42+
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
43+
grid_batch, grid_x, grid_y = torch.meshgrid(
44+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
45+
torch.arange(x.size(2), dtype=torch.long, device=x.device),
46+
torch.arange(x.size(3), dtype=torch.long, device=x.device),
47+
)
48+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
49+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
50+
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
51+
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
52+
return x
53+
54+
55+
def rand_cutout(x, ratio=0.5):
56+
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
57+
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
58+
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
59+
grid_batch, grid_x, grid_y = torch.meshgrid(
60+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
61+
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
62+
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
63+
)
64+
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
65+
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
66+
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
67+
mask[grid_batch, grid_x, grid_y] = 0
68+
x = x * mask.unsqueeze(1)
69+
return x
70+
71+
72+
AUGMENT_FNS = {
73+
'color': [rand_brightness, rand_saturation, rand_contrast],
74+
'translation': [rand_translation],
75+
'cutout': [rand_cutout],
76+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch
2+
3+
from models import Generator, weights_init
4+
5+
import matplotlib.pyplot as plt
6+
7+
import os
8+
9+
from collections import OrderedDict
10+
11+
import numpy as np
12+
13+
from skimage import io
14+
15+
16+
device = 'cuda'
17+
18+
ngf = 64
19+
nz = 256
20+
im_size = 1024
21+
netG = Generator(ngf=ngf, nz=nz, im_size=im_size).to(device)
22+
weights_init(netG)
23+
weights = torch.load(os.getcwd() + '/train_results/test1/models/50000.pth')
24+
netG_weights = OrderedDict()
25+
for name, weight in weights['g'].items():
26+
name = name.split('.')[1:]
27+
name = '.'.join(name)
28+
netG_weights[name] = weight
29+
netG.load_state_dict(netG_weights)
30+
current_batch_size = 1
31+
32+
33+
index = 1
34+
while index <= 3000:
35+
noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device)
36+
fake_images = netG(noise)[0]
37+
for fake_image in fake_images:
38+
fake_image = fake_image.detach().cpu().numpy().transpose(1, 2, 0)
39+
fake_image = fake_image * np.array([0.5, 0.5, 0.5])
40+
fake_image = fake_image + np.array([0.5, 0.5, 0.5])
41+
fake_image = (fake_image * 255).astype(np.uint8)
42+
io.imsave('../data/fake_imgs1/' + str(index) + '.png', fake_image)
43+
index += 1
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import numpy as np
7+
import skimage
8+
import torch
9+
from torch.autograd import Variable
10+
11+
from lpips import dist_model
12+
13+
14+
from skimage.metrics import structural_similarity as compare_ssim
15+
16+
17+
class PerceptualLoss(torch.nn.Module):
18+
# VGG using our perceptually-learned weights (LPIPS metric)
19+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]):
20+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
21+
super(PerceptualLoss, self).__init__()
22+
print('Setting up Perceptual loss...')
23+
self.use_gpu = use_gpu
24+
self.spatial = spatial
25+
self.gpu_ids = gpu_ids
26+
self.model = dist_model.DistModel()
27+
self.model.initialize(model=model, net=net, use_gpu=use_gpu,
28+
colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
29+
print('...[%s] initialized' % self.model.name())
30+
print('...Done')
31+
32+
def forward(self, pred, target, normalize=False):
33+
"""
34+
Pred and target are Variables.
35+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
36+
If normalize is False, assumes the images are already between [-1,+1]
37+
38+
Inputs pred and target are Nx3xHxW
39+
Output pytorch Variable N long
40+
"""
41+
42+
if normalize:
43+
target = 2 * target - 1
44+
pred = 2 * pred - 1
45+
46+
return self.model.forward(target, pred)
47+
48+
49+
def normalize_tensor(in_feat, eps=1e-10):
50+
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
51+
return in_feat/(norm_factor+eps)
52+
53+
54+
def l2(p0, p1, range=255.):
55+
return .5*np.mean((p0 / range - p1 / range)**2)
56+
57+
58+
def psnr(p0, p1, peak=255.):
59+
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
60+
61+
62+
def dssim(p0, p1, range=255.):
63+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
64+
65+
66+
def rgb2lab(in_img, mean_cent=False):
67+
from skimage import color
68+
img_lab = color.rgb2lab(in_img)
69+
if(mean_cent):
70+
img_lab[:, :, 0] = img_lab[:, :, 0]-50
71+
return img_lab
72+
73+
74+
def tensor2np(tensor_obj):
75+
# change dimension of a tensor object into a numpy array
76+
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
77+
78+
79+
def np2tensor(np_obj):
80+
# change dimenion of np array into tensor array
81+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
82+
83+
84+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
85+
# image tensor to lab tensor
86+
from skimage import color
87+
88+
img = tensor2im(image_tensor)
89+
img_lab = color.rgb2lab(img)
90+
if(mc_only):
91+
img_lab[:, :, 0] = img_lab[:, :, 0]-50
92+
if(to_norm and not mc_only):
93+
img_lab[:, :, 0] = img_lab[:, :, 0]-50
94+
img_lab = img_lab/100.
95+
96+
return np2tensor(img_lab)
97+
98+
99+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
100+
from skimage import color
101+
import warnings
102+
warnings.filterwarnings("ignore")
103+
104+
lab = tensor2np(lab_tensor)*100.
105+
lab[:, :, 0] = lab[:, :, 0]+50
106+
107+
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
108+
if(return_inbnd):
109+
# convert back to lab, see if we match
110+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
111+
mask = 1.*np.isclose(lab_back, lab, atol=2.)
112+
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
113+
return (im2tensor(rgb_back), mask)
114+
else:
115+
return im2tensor(rgb_back)
116+
117+
118+
def rgb2lab(input):
119+
from skimage import color
120+
return color.rgb2lab(input / 255.)
121+
122+
123+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
124+
image_numpy = image_tensor[0].cpu().float().numpy()
125+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
126+
return image_numpy.astype(imtype)
127+
128+
129+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
130+
return torch.Tensor((image / factor - cent)
131+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
132+
133+
134+
def tensor2vec(vector_tensor):
135+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
136+
137+
138+
def voc_ap(rec, prec, use_07_metric=False):
139+
""" ap = voc_ap(rec, prec, [use_07_metric])
140+
Compute VOC AP given precision and recall.
141+
If use_07_metric is true, uses the
142+
VOC 07 11 point method (default:False).
143+
"""
144+
if use_07_metric:
145+
# 11 point metric
146+
ap = 0.
147+
for t in np.arange(0., 1.1, 0.1):
148+
if np.sum(rec >= t) == 0:
149+
p = 0
150+
else:
151+
p = np.max(prec[rec >= t])
152+
ap = ap + p / 11.
153+
else:
154+
# correct AP calculation
155+
# first append sentinel values at the end
156+
mrec = np.concatenate(([0.], rec, [1.]))
157+
mpre = np.concatenate(([0.], prec, [0.]))
158+
159+
# compute the precision envelope
160+
for i in range(mpre.size - 1, 0, -1):
161+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
162+
163+
# to calculate area under PR curve, look for points
164+
# where X axis (recall) changes value
165+
i = np.where(mrec[1:] != mrec[:-1])[0]
166+
167+
# and sum (\Delta recall) * prec
168+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
169+
return ap
170+
171+
172+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
173+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
174+
image_numpy = image_tensor[0].cpu().float().numpy()
175+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
176+
return image_numpy.astype(imtype)
177+
178+
179+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
180+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
181+
return torch.Tensor((image / factor - cent)
182+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import torch
3+
from torch.autograd import Variable
4+
from pdb import set_trace as st
5+
from IPython import embed
6+
7+
class BaseModel():
8+
def __init__(self):
9+
pass;
10+
11+
def name(self):
12+
return 'BaseModel'
13+
14+
def initialize(self, use_gpu=True, gpu_ids=[0]):
15+
self.use_gpu = use_gpu
16+
self.gpu_ids = gpu_ids
17+
18+
def forward(self):
19+
pass
20+
21+
def get_image_paths(self):
22+
pass
23+
24+
def optimize_parameters(self):
25+
pass
26+
27+
def get_current_visuals(self):
28+
return self.input
29+
30+
def get_current_errors(self):
31+
return {}
32+
33+
def save(self, label):
34+
pass
35+
36+
# helper saving function that can be used by subclasses
37+
def save_network(self, network, path, network_label, epoch_label):
38+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
39+
save_path = os.path.join(path, save_filename)
40+
torch.save(network.state_dict(), save_path)
41+
42+
# helper loading function that can be used by subclasses
43+
def load_network(self, network, network_label, epoch_label):
44+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45+
save_path = os.path.join(self.save_dir, save_filename)
46+
print('Loading network from %s'%save_path)
47+
network.load_state_dict(torch.load(save_path))
48+
49+
def update_learning_rate():
50+
pass
51+
52+
def get_image_paths(self):
53+
return self.image_paths
54+
55+
def save_done(self, flag=False):
56+
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
57+
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58+

0 commit comments

Comments
 (0)