diff --git a/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/[OSPP]A Proposal of Integrating GAN and Self-taught Learning to ianvs.md b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/[OSPP]A Proposal of Integrating GAN and Self-taught Learning to ianvs.md new file mode 100644 index 00000000..c90e3f1a --- /dev/null +++ b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/[OSPP]A Proposal of Integrating GAN and Self-taught Learning to ianvs.md @@ -0,0 +1,33 @@ +# Integrate GAN and Self-taught Learning into ianvs Lifelong Learning to Handle Unknown Tasks + +## Motivation + +In the process of ianvs lifelong learning, there is a chance of confronting unknown tasks, whose data are typically heterogeneous and small-sample. Generative Adversarial Networks (GANs) are state-of-the-art generative models that can generate synthetic data following the distribution of the real data. Naturally, we utilize a GAN to handle the small-sample problem. Self-taught learning is an approach that improves classification performance by using sparse coding to construct higher-level features from unlabeled data. Hence, we combine GAN and self-taught learning to help ianvs lifelong learning handle unknown tasks. + +### Goals + +* Handle unknown tasks +* Implement a lightweight GAN to solve the small-sample problem +* Utilize self-taught learning to solve the heterogeneity problem + +## Proposal +We focus on the process of handling unknown tasks. + +The overview is as follows: + +![](images/overview.png) + +The process is as follows: +1. The GAN exploits the unknown-task samples to generate more synthetic samples. +2. The self-taught learning unit utilizes the synthetic samples together with the original unknown-task samples and their labels to train a classifier. +3. A well-trained classifier is output. + +### GAN Design +We use the networks designed in [Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis](https://openreview.net/forum?id=1Fqg133qRaI). The design targets small training sets and weak computing devices, so it is well suited to handling unknown tasks in ianvs lifelong learning. The networks are shown below. + +![](images/GAN.png) + +### Self-taught Learning Design +Self-taught learning first uses unlabeled data to learn latent features, then represents each labeled sample in terms of those features, and finally trains a classifier on these representations together with the corresponding labels, as sketched below.
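To make the data flow of this stage concrete, here is a minimal PyTorch sketch. It is illustrative only and is not the proposed implementation: it uses the convolutional autoencoder variant described in the accompanying example (rather than sparse coding), and all names (`SimpleAutoEncoder`, `self_taught_learning`, the two data loaders, and the external `classifier`) are hypothetical.

```python
# Illustrative sketch only; names and shapes are assumptions, not part of this proposal.
import torch
import torch.nn as nn


class SimpleAutoEncoder(nn.Module):
    """Toy convolutional autoencoder standing in for the self-taught learning unit."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1), nn.Tanh())

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z


def self_taught_learning(unlabeled_loader, labeled_loader, classifier, epochs=10):
    """unlabeled_loader: GAN-generated images; labeled_loader: (image, label) pairs."""
    ae = SimpleAutoEncoder()
    ae_opt = torch.optim.Adam(ae.parameters(), lr=1e-3)
    mse = nn.MSELoss()
    # Step 1: learn latent features by reconstructing the unlabeled (GAN-generated) data.
    for _ in range(epochs):
        for x in unlabeled_loader:
            rec, _ = ae(x)
            loss = mse(rec, x)
            ae_opt.zero_grad()
            loss.backward()
            ae_opt.step()
    # Step 2: re-represent the labeled samples with the frozen encoder.
    # Step 3: train the classifier on these representations and their labels.
    ae.encoder.requires_grad_(False)
    clf_opt = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in labeled_loader:
            z = ae.encoder(x)
            logits = classifier(z.flatten(1))  # classifier consumes flattened features
            loss = ce(logits, y)
            clf_opt.zero_grad()
            loss.backward()
            clf_opt.step()
    return classifier
```

The key design point is that the reconstruction objective never sees labels, so the encoder can be trained entirely on GAN output; only the final classifier consumes the scarce labeled unknown-task samples.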
+ +![](images/self-taught%20learning.png) \ No newline at end of file diff --git a/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/GAN.png b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/GAN.png new file mode 100644 index 00000000..6712fc76 Binary files /dev/null and b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/GAN.png differ diff --git a/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/overview.png b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/overview.png new file mode 100644 index 00000000..4f1c1575 Binary files /dev/null and b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/overview.png differ diff --git a/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/self-taught learning.png b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/self-taught learning.png new file mode 100644 index 00000000..f42a1a5d Binary files /dev/null and b/docs/proposals/algorithms/[OSPP]GAN and Self-taught Learning/images/self-taught learning.png differ diff --git a/examples/GANwithSelf-taughtLearning/GAN/__init__.py b/examples/GANwithSelf-taughtLearning/GAN/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/GANwithSelf-taughtLearning/GAN/diffaug.py b/examples/GANwithSelf-taughtLearning/GAN/diffaug.py new file mode 100644 index 00000000..54c0894f --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/diffaug.py @@ -0,0 +1,76 @@ +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 + +import torch +import torch.nn.functional as F + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], 
device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} \ No newline at end of file diff --git a/examples/GANwithSelf-taughtLearning/GAN/generate_fake_imgs.py b/examples/GANwithSelf-taughtLearning/GAN/generate_fake_imgs.py new file mode 100644 index 00000000..4f0bb609 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/generate_fake_imgs.py @@ -0,0 +1,46 @@ +import torch + +from models import Generator, weights_init + +import matplotlib.pyplot as plt + +import os + +from collections import OrderedDict + +import numpy as np + +from skimage import io + +# print(os.getcwd()) + +device = 'cuda' + +ngf = 64 +nz = 256 +im_size = 1024 +netG = Generator(ngf=ngf, nz=nz, im_size=im_size).to(device) +weights_init(netG) +weights = torch.load(os.getcwd() + '/train_results/test1/models/50000.pth') +# print(weights['g']) +netG_weights = OrderedDict() +for name, weight in weights['g'].items(): + name = name.split('.')[1:] + name = '.'.join(name) + netG_weights[name] = weight +netG.load_state_dict(netG_weights) +current_batch_size = 1 + + +index = 1 +while index <= 3000: + noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device) + fake_images = netG(noise)[0] + for fake_image in fake_images: + fake_image = fake_image.detach().cpu().numpy().transpose(1, 2, 0) + fake_image = fake_image * np.array([0.5, 0.5, 0.5]) + fake_image = fake_image + np.array([0.5, 0.5, 0.5]) + fake_image = (fake_image * 255).astype(np.uint8) + io.imsave('../data/fake_imgs1/' + str(index) + '.png', fake_image) + print('figure {} done'.format(index)) + index += 1 \ No newline at end of file diff --git a/examples/GANwithSelf-taughtLearning/GAN/lpips/__init__.py b/examples/GANwithSelf-taughtLearning/GAN/lpips/__init__.py new file mode 100644 index 00000000..1878b3ff --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/lpips/__init__.py @@ -0,0 +1,168 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import skimage +import torch +from torch.autograd import Variable + +from lpips import dist_model + + +if skimage.__version__ == '0.14.3': + from skimage.measure import compare_ssim +else: + from skimage.metrics import structural_similarity as compare_ssim + + + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + print('Setting up Perceptual loss...') + self.use_gpu = use_gpu + self.spatial = spatial + self.gpu_ids = gpu_ids 
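+ # DistModel (defined below in dist_model.py) builds the chosen backbone ('squeeze'/'alex'/'vgg'), optionally adds the linearly calibrated LPIPS layer, and loads the pretrained weights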
+ self.model = dist_model.DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) + print('...[%s] initialized'%self.model.name()) + print('...Done') + + def forward(self, pred, target, normalize=False): + """ + Pred and target are Variables. + If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model.forward(target, pred) + +def normalize_tensor(in_feat,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/(norm_factor+eps) + +def l2(p0, p1, range=255.): + return .5*np.mean((p0 / range - p1 / range)**2) + +def psnr(p0, p1, peak=255.): + return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) + +def dssim(p0, p1, range=255.): + return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. + +def rgb2lab(in_img,mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if(mean_cent): + img_lab[:,:,0] = img_lab[:,:,0]-50 + return img_lab + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if(mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + if(to_norm and not mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + img_lab = img_lab/100. + + return np2tensor(img_lab) + +def tensorlab2tensor(lab_tensor,return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor)*100. + lab[:,:,0] = lab[:,:,0]+50 + + rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) + if(return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1.*np.isclose(lab_back,lab,atol=2.) + mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) + return (im2tensor(rgb_back),mask) + else: + return im2tensor(rgb_back) + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
+ else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/examples/GANwithSelf-taughtLearning/GAN/lpips/base_model.py b/examples/GANwithSelf-taughtLearning/GAN/lpips/base_model.py new file mode 100644 index 00000000..9fdb9306 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/lpips/base_model.py @@ -0,0 +1,58 @@ +import os +import numpy as np +import torch +from torch.autograd import Variable +from pdb import set_trace as st +from IPython import embed + +class BaseModel(): + def __init__(self): + pass + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + return self.image_paths + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s'%save_path) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(self): + pass + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') + diff --git a/examples/GANwithSelf-taughtLearning/GAN/lpips/dist_model.py b/examples/GANwithSelf-taughtLearning/GAN/lpips/dist_model.py new file mode 100644 index 00000000..4ff0aa4c --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/lpips/dist_model.py @@ -0,0 +1,284 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm + +from IPython import embed + +from .
import networks_basic as networks +import lpips as util + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]'%(model,net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + import inspect + model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) + + if(not is_train): + print('Loading model from: %s'%model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model=='net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2','l2']): + self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM','dssim','SSIM','ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." 
% self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size==(1,1)): + module.weight.data = torch.clamp(module.weight.data,min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref,requires_grad=True) + self.var_p0 = Variable(self.input_p0,requires_grad=True) + self.var_p1 = Variable(self.input_p1,requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) + + self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
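+ # BCERankingLoss maps the distance pair (d0, d1) to a predicted preference and scores it against the human judgement, which is rescaled from [0,1] to [-1,+1]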
+ + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self,d0,d1,judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) + + def get_current_errors(self): + retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), + ('acc_r', self.acc_r)]) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256/self.var_ref.data.size()[2] + + ref_img = util.tensor2im(self.var_ref.data) + p0_img = util.tensor2im(self.var_p0.data) + p1_img = util.tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0) + p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0) + p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0) + + return OrderedDict([('ref', ref_img_vis), + ('p0', p0_img_vis), + ('p1', p1_img_vis)]) + + def save(self, path, label): + if(self.use_gpu): + self.save_network(self.net.module, path, '', label) + else: + self.save_network(self.net, path, '', label) + self.save_done(True) + + def update_learning_rate(self,nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group['lr'] = lr + + print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) + self.old_lr = lr + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() + d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() + gts+=data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d0s==d1s)*.5 + + return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores)) [...] if im_size > 256: + self.feat_512 = UpBlockComp(nfc[256], nfc[512]) + self.se_512 = SEBlock(nfc[32], nfc[512]) + if im_size > 512: + self.feat_1024 = UpBlock(nfc[512], nfc[1024]) + + def forward(self, input): + + feat_4 = self.init(input) + feat_8 = self.feat_8(feat_4) + feat_16 = self.feat_16(feat_8) + feat_32 = self.feat_32(feat_16) + + feat_64 = self.se_64(feat_4, self.feat_64(feat_32)) + + feat_128 = self.se_128(feat_8, self.feat_128(feat_64)) + + feat_256 = self.se_256(feat_16, self.feat_256(feat_128)) + + if self.im_size == 256: + return [self.to_big(feat_256), self.to_128(feat_128)] + + feat_512 = self.se_512(feat_32, self.feat_512(feat_256)) + if self.im_size == 512: + return [self.to_big(feat_512), self.to_128(feat_128)] + + feat_1024 = self.feat_1024(feat_512) + + im_128 = torch.tanh(self.to_128(feat_128)) + im_1024 = torch.tanh(self.to_big(feat_1024)) + + return [im_1024, im_128] + + +class DownBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(DownBlock, self).__init__() + + self.main = nn.Sequential( + conv2d(in_planes, out_planes, 4, 2, 1, bias=False), + batchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, feat): + return self.main(feat) + + +class DownBlockComp(nn.Module): + def __init__(self, in_planes, out_planes): + super(DownBlockComp, self).__init__() + + self.main = nn.Sequential( + conv2d(in_planes, out_planes, 4, 2, 1, bias=False), + batchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True), + conv2d(out_planes, out_planes, 3, 1, 1, bias=False), + batchNorm2d(out_planes), nn.LeakyReLU(0.2) + ) + + self.direct = nn.Sequential( + nn.AvgPool2d(2, 2), + conv2d(in_planes, out_planes, 1, 1, 0, bias=False), + batchNorm2d(out_planes), nn.LeakyReLU(0.2)) + + def forward(self, feat): + return (self.main(feat) + self.direct(feat)) / 2 + + +class Discriminator(nn.Module): + def __init__(self, ndf=64, nc=3, im_size=512): + super(Discriminator,
self).__init__() + self.ndf = ndf + self.im_size = im_size + + nfc_multi = {4: 16, 8: 16, 16: 8, 32: 4, 64: 2, 128: 1, 256: 0.5, 512: 0.25, 1024: 0.125} + nfc = {} + for k, v in nfc_multi.items(): + nfc[k] = int(v * ndf) + + if im_size == 1024: + self.down_from_big = nn.Sequential( + conv2d(nc, nfc[1024], 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + conv2d(nfc[1024], nfc[512], 4, 2, 1, bias=False), + batchNorm2d(nfc[512]), + nn.LeakyReLU(0.2, inplace=True)) + elif im_size == 512: + self.down_from_big = nn.Sequential( + conv2d(nc, nfc[512], 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)) + elif im_size == 256: + self.down_from_big = nn.Sequential( + conv2d(nc, nfc[512], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)) + + self.down_4 = DownBlockComp(nfc[512], nfc[256]) + self.down_8 = DownBlockComp(nfc[256], nfc[128]) + self.down_16 = DownBlockComp(nfc[128], nfc[64]) + self.down_32 = DownBlockComp(nfc[64], nfc[32]) + self.down_64 = DownBlockComp(nfc[32], nfc[16]) + + self.rf_big = nn.Sequential( + conv2d(nfc[16], nfc[8], 1, 1, 0, bias=False), + batchNorm2d(nfc[8]), nn.LeakyReLU(0.2, inplace=True), + conv2d(nfc[8], 1, 4, 1, 0, bias=False)) + + self.se_2_16 = SEBlock(nfc[512], nfc[64]) + self.se_4_32 = SEBlock(nfc[256], nfc[32]) + self.se_8_64 = SEBlock(nfc[128], nfc[16]) + + self.down_from_small = nn.Sequential( + conv2d(nc, nfc[256], 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + DownBlock(nfc[256], nfc[128]), + DownBlock(nfc[128], nfc[64]), + DownBlock(nfc[64], nfc[32]), ) + + self.rf_small = conv2d(nfc[32], 1, 4, 1, 0, bias=False) + + self.decoder_big = SimpleDecoder(nfc[16], nc) + self.decoder_part = SimpleDecoder(nfc[32], nc) + self.decoder_small = SimpleDecoder(nfc[32], nc) + + def forward(self, imgs, label, part=None): + if type(imgs) is not list: + imgs = [F.interpolate(imgs, size=self.im_size), F.interpolate(imgs, size=128)] + + feat_2 = self.down_from_big(imgs[0]) + feat_4 = self.down_4(feat_2) + feat_8 = self.down_8(feat_4) + + feat_16 = self.down_16(feat_8) + feat_16 = self.se_2_16(feat_2, feat_16) + + feat_32 = self.down_32(feat_16) + feat_32 = self.se_4_32(feat_4, feat_32) + + feat_last = self.down_64(feat_32) + feat_last = self.se_8_64(feat_8, feat_last) + + # rf_0 = torch.cat([self.rf_big_1(feat_last).view(-1),self.rf_big_2(feat_last).view(-1)]) + # rff_big = torch.sigmoid(self.rf_factor_big) + rf_0 = self.rf_big(feat_last).view(-1) + + feat_small = self.down_from_small(imgs[1]) + # rf_1 = torch.cat([self.rf_small_1(feat_small).view(-1),self.rf_small_2(feat_small).view(-1)]) + rf_1 = self.rf_small(feat_small).view(-1) + + if label == 'real': + rec_img_big = self.decoder_big(feat_last) + rec_img_small = self.decoder_small(feat_small) + + assert part is not None + rec_img_part = None + if part == 0: + rec_img_part = self.decoder_part(feat_32[:, :, :8, :8]) + if part == 1: + rec_img_part = self.decoder_part(feat_32[:, :, :8, 8:]) + if part == 2: + rec_img_part = self.decoder_part(feat_32[:, :, 8:, :8]) + if part == 3: + rec_img_part = self.decoder_part(feat_32[:, :, 8:, 8:]) + + return torch.cat([rf_0, rf_1]), [rec_img_big, rec_img_small, rec_img_part] + + return torch.cat([rf_0, rf_1]) + + +class SimpleDecoder(nn.Module): + """docstring for CAN_SimpleDecoder""" + + def __init__(self, nfc_in=64, nc=3): + super(SimpleDecoder, self).__init__() + + nfc_multi = {4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1, 256: 0.5, 512: 0.25, 1024: 0.125} + nfc = {} + for k, v in nfc_multi.items(): + nfc[k] = int(v * 32) + + def upBlock(in_planes, 
out_planes): + block = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False), + batchNorm2d(out_planes * 2), GLU()) + return block + + self.main = nn.Sequential(nn.AdaptiveAvgPool2d(8), + upBlock(nfc_in, nfc[16]), + upBlock(nfc[16], nfc[32]), + upBlock(nfc[32], nfc[64]), + upBlock(nfc[64], nfc[128]), + conv2d(nfc[128], nc, 3, 1, 1, bias=False), + nn.Tanh()) + + def forward(self, input): + # input shape: c x 4 x 4 + return self.main(input) + + +from random import randint + + +def random_crop(image, size): + h, w = image.shape[2:] + ch = randint(0, h - size - 1) + cw = randint(0, w - size - 1) + return image[:, :, ch:ch + size, cw:cw + size] + + +class TextureDiscriminator(nn.Module): + def __init__(self, ndf=64, nc=3, im_size=512): + super(TextureDiscriminator, self).__init__() + self.ndf = ndf + self.im_size = im_size + + nfc_multi = {4: 16, 8: 8, 16: 8, 32: 4, 64: 2, 128: 1, 256: 0.5, 512: 0.25, 1024: 0.125} + nfc = {} + for k, v in nfc_multi.items(): + nfc[k] = int(v * ndf) + + self.down_from_small = nn.Sequential( + conv2d(nc, nfc[256], 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + DownBlock(nfc[256], nfc[128]), + DownBlock(nfc[128], nfc[64]), + DownBlock(nfc[64], nfc[32]), ) + self.rf_small = nn.Sequential( + conv2d(nfc[16], 1, 4, 1, 0, bias=False)) + + self.decoder_small = SimpleDecoder(nfc[32], nc) + + def forward(self, img, label): + img = random_crop(img, size=128) + + feat_small = self.down_from_small(img) + rf = self.rf_small(feat_small).view(-1) + + if label == 'real': + rec_img_small = self.decoder_small(feat_small) + + return rf, rec_img_small, img + + return rf \ No newline at end of file diff --git a/examples/GANwithSelf-taughtLearning/GAN/operation.py b/examples/GANwithSelf-taughtLearning/GAN/operation.py new file mode 100644 index 00000000..f61e46bb --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/operation.py @@ -0,0 +1,137 @@ +import os +import numpy as np +import torch +import torch.utils.data as data +from torch.utils.data import Dataset +from PIL import Image +from copy import deepcopy +import shutil +import json + +def InfiniteSampler(n): + """Data sampler""" + i = n - 1 + order = np.random.permutation(n) + while True: + yield order[i] + i += 1 + if i >= n: + np.random.seed() + order = np.random.permutation(n) + i = 0 + + +class InfiniteSamplerWrapper(data.sampler.Sampler): + """Data sampler wrapper""" + def __init__(self, data_source): + self.num_samples = len(data_source) + + def __iter__(self): + return iter(InfiniteSampler(self.num_samples)) + + def __len__(self): + return 2 ** 31 + + +def copy_G_params(model): + flatten = deepcopy(list(p.data for p in model.parameters())) + return flatten + + +def load_params(model, new_param): + for p, new_p in zip(model.parameters(), new_param): + p.data.copy_(new_p) + + +def get_dir(args): + task_name = 'train_results/' + args['name'] + saved_model_folder = os.path.join( task_name, 'models') + saved_image_folder = os.path.join( task_name, 'images') + + os.makedirs(saved_model_folder, exist_ok=True) + os.makedirs(saved_image_folder, exist_ok=True) + + # for f in os.listdir('./'): + # if '.py' in f: + # shutil.copy(f, task_name+'/'+f) + + # with open( os.path.join(saved_model_folder, '../args.txt'), 'w') as f: + # json.dump(args.__dict__, f, indent=2) + + return saved_model_folder, saved_image_folder + + +class ImageFolder(Dataset): + """docstring for ArtDataset""" + def __init__(self, root, transform=None): + super( ImageFolder, 
self).__init__() + self.root = root + + self.frame = self._parse_frame() + self.transform = transform + + def _parse_frame(self): + frame = [] + img_names = os.listdir(self.root) + img_names.sort() + for i in range(len(img_names)): + image_path = os.path.join(self.root, img_names[i]) + if image_path[-4:] == '.jpg' or image_path[-4:] == '.png' or image_path[-5:] == '.jpeg': + frame.append(image_path) + return frame + + def __len__(self): + return len(self.frame) + + def __getitem__(self, idx): + file = self.frame[idx] + img = Image.open(file).convert('RGB') + + if self.transform: + img = self.transform(img) + + return img + + + +from io import BytesIO +import lmdb +from torch.utils.data import Dataset + + +class MultiResolutionDataset(Dataset): + def __init__(self, path, transform, resolution=256): + self.env = lmdb.open( + path, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + if not self.env: + raise IOError('Cannot open lmdb dataset', path) + + with self.env.begin(write=False) as txn: + self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) + + self.resolution = resolution + self.transform = transform + + def __len__(self): + return self.length + + def __getitem__(self, index): + with self.env.begin(write=False) as txn: + key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') + img_bytes = txn.get(key) + #key_asp = f'aspect_ratio-{str(index).zfill(5)}'.encode('utf-8') + #aspect_ratio = float(txn.get(key_asp).decode()) + + buffer = BytesIO(img_bytes) + img = Image.open(buffer) + img = self.transform(img) + + return img + diff --git a/examples/GANwithSelf-taughtLearning/GAN/train.py b/examples/GANwithSelf-taughtLearning/GAN/train.py new file mode 100644 index 00000000..2d3be08f --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/train.py @@ -0,0 +1,220 @@ +import torch +from torch import nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data.dataloader import DataLoader +from torchvision import transforms +from torchvision import utils as vutils + +import argparse +import random +from tqdm import tqdm + +import csv + +from models import weights_init, Discriminator, Generator +from operation import copy_G_params, load_params, get_dir +from operation import ImageFolder, InfiniteSamplerWrapper +from diffaug import DiffAugment +policy = 'color,translation' +import lpips +percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) + +from util import load_yaml + +def crop_image_by_part(image, part): + hw = image.shape[2] // 2 + if part == 0: + return image[:, :, :hw, :hw] + if part == 1: + return image[:, :, :hw, hw:] + if part == 2: + return image[:, :, hw:, :hw] + if part == 3: + return image[:, :, hw:, hw:] + + +def train_d(net, data, label="real"): + """Train function of discriminator""" + if label == "real": + part = random.randint(0, 3) + pred, [rec_all, rec_small, rec_part] = net(data, label, part=part) + err = F.relu(torch.rand_like(pred) * 0.2 + 0.8 - pred).mean() + \ + percept(rec_all, F.interpolate(data, rec_all.shape[2])).sum() + \ + percept(rec_small, F.interpolate(data, rec_small.shape[2])).sum() + \ + percept(rec_part, F.interpolate(crop_image_by_part(data, part), rec_part.shape[2])).sum() + err.backward() + return pred.mean().item(), rec_all, rec_small, rec_part + else: + pred = net(data, label) + err = F.relu(torch.rand_like(pred) * 0.2 + 0.8 + pred).mean() + err.backward() + return pred.mean().item() + + +def train(args): + data_root = args['path'] + 
total_iterations = args['iter'] + # checkpoint = args.ckpt + batch_size = args['batch_size'] + im_size = args['im_size'] + ndf = 64 + ngf = 64 + nz = 256 + nlr = 0.0002 + nbeta1 = 0.5 + use_cuda = True + multi_gpu = False + dataloader_workers = 8 + current_iteration = 0 + save_interval = 100 + saved_model_folder, saved_image_folder = get_dir(args) + + device = torch.device("cpu") + if use_cuda: + device = torch.device("cuda:0") + + transform_list = [ + transforms.Resize((int(im_size), int(im_size))), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ] + trans = transforms.Compose(transform_list) + + if 'lmdb' in data_root: + from operation import MultiResolutionDataset + dataset = MultiResolutionDataset(data_root, trans, 1024) + else: + dataset = ImageFolder(root=data_root, transform=trans) + + dataloader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=False, + sampler=InfiniteSamplerWrapper(dataset), num_workers=dataloader_workers, + pin_memory=True)) + ''' + loader = MultiEpochsDataLoader(dataset, batch_size=batch_size, + shuffle=True, num_workers=dataloader_workers, + pin_memory=True) + dataloader = CudaDataLoader(loader, 'cuda') + ''' + + # from model_s import Generator, Discriminator + netG = Generator(ngf=ngf, nz=nz, im_size=im_size) + netG.apply(weights_init) + + netD = Discriminator(ndf=ndf, im_size=im_size) + netD.apply(weights_init) + + netG.to(device) + netD.to(device) + + avg_param_G = copy_G_params(netG) + + fixed_noise = torch.FloatTensor(8, nz).normal_(0, 1).to(device) + + optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999)) + optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999)) + + # if checkpoint != 'None': + # ckpt = torch.load(checkpoint) + # netG.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['g'].items()}) + # netD.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['d'].items()}) + # avg_param_G = ckpt['g_ema'] + # optimizerG.load_state_dict(ckpt['opt_g']) + # optimizerD.load_state_dict(ckpt['opt_d']) + # current_iteration = int(checkpoint.split('_')[-1].split('.')[0]) + # del ckpt + + if multi_gpu: + netG = nn.DataParallel(netG.to(device)) + netD = nn.DataParallel(netD.to(device)) + + with open('train_cityscape.csv', 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['epoch', 'd_loss', 'g_loss']) + + for iteration in tqdm(range(current_iteration, total_iterations + 1)): + real_image = next(dataloader) + real_image = real_image.to(device) + current_batch_size = real_image.size(0) + noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device) + + fake_images = netG(noise) + + real_image = DiffAugment(real_image, policy=policy) + fake_images = [DiffAugment(fake, policy=policy) for fake in fake_images] + + ## 2. train Discriminator + netD.zero_grad() + + err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(netD, real_image, label="real") + train_d(netD, [fi.detach() for fi in fake_images], label="fake") + optimizerD.step() + + ## 3. 
train Generator + netG.zero_grad() + pred_g = netD(fake_images, "fake") + err_g = -pred_g.mean() + + err_g.backward() + optimizerG.step() + + for p, avg_p in zip(netG.parameters(), avg_param_G): + avg_p.mul_(0.999).add_(0.001 * p.data) + + if iteration % 100 == 0: + print("GAN: loss d: %.5f loss g: %.5f" % (err_dr, -err_g.item())) + + if iteration % (save_interval * 10) == 0: + backup_para = copy_G_params(netG) + load_params(netG, avg_param_G) + with torch.no_grad(): + vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5), saved_image_folder + '/%d.jpg' % iteration, + nrow=4) + vutils.save_image(torch.cat([ + F.interpolate(real_image, 128), + rec_img_all, rec_img_small, + rec_img_part]).add(1).mul(0.5), saved_image_folder + '/rec_%d.jpg' % iteration) + load_params(netG, backup_para) + + if iteration % (save_interval * 50) == 0 or iteration == total_iterations: + backup_para = copy_G_params(netG) + load_params(netG, avg_param_G) + torch.save({'g': netG.state_dict(), 'd': netD.state_dict()}, saved_model_folder + '/%d.pth' % iteration) + load_params(netG, backup_para) + torch.save({'g': netG.state_dict(), + 'd': netD.state_dict(), + 'g_ema': avg_param_G, + 'opt_g': optimizerG.state_dict(), + 'opt_d': optimizerD.state_dict()}, saved_model_folder + '/all_%d.pth' % iteration) + + with open('train_cityscape.csv', 'a') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([iteration, err_dr, -err_g.item()]) + + +if __name__ == "__main__": + # parser = argparse.ArgumentParser(description='region gan') + # + # parser.add_argument('--path', type=str, default='../lmdbs/art_landscape_1k', + # help='path of resource dataset, should be a folder that has one or many sub image folders inside') + # parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use') + # parser.add_argument('--name', type=str, default='test1', help='experiment name') + # parser.add_argument('--iter', type=int, default=50000, help='number of iterations') + # parser.add_argument('--start_iter', type=int, default=0, help='the iteration to start training') + # parser.add_argument('--batch_size', type=int, default=8, help='mini batch number of images') + # parser.add_argument('--im_size', type=int, default=1024, help='image resolution') + # parser.add_argument('--ckpt', type=str, default='None', help='checkpoint weight path if have one') + + # args = parser.parse_args() + configs = load_yaml('../config.yaml') + print(configs) + args = dict() + args['path'] = configs['dataset_path'] + args['iter'] = configs['GAN'][0]['iter'] + args['batch_size'] = configs['GAN'][1]['batch_size'] + args['im_size'] = configs['GAN'][2]['im_size'] + args['name'] = configs['GAN'][3]['name'] + print(args) + # + train(args) diff --git a/examples/GANwithSelf-taughtLearning/GAN/train_cityscape.csv b/examples/GANwithSelf-taughtLearning/GAN/train_cityscape.csv new file mode 100644 index 00000000..fe7aa9fd --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/GAN/train_cityscape.csv @@ -0,0 +1,18 @@ +epoch,d_loss,g_loss +0,0.20575577020645142,-0.3701064884662628 +1,0.1202360987663269,-0.382342666387558 +2,0.2591087818145752,-0.5365990996360779 +3,0.595060408115387,-0.6694589257240295 +4,0.4233413636684418,-0.45368048548698425 +5,0.7397832274436951,-0.513699471950531 +6,0.4850901961326599,-0.730381429195404 +7,0.6998284459114075,-0.7783517241477966 +8,0.7868878841400146,-0.8684926629066467 +9,0.6464230418205261,-0.9065934419631958 +10,0.6726473569869995,-0.8063123822212219 +11,0.6846868991851807,-0.7946229577064514 
+12,0.6749208569526672,-0.8458480834960938 +13,0.7467646598815918,-0.9234620332717896 +14,1.1013246774673462,-1.1124314069747925 +15,0.5670914053916931,-0.8788895010948181 +16,0.814713716506958,-0.9621263146400452 diff --git a/examples/GANwithSelf-taughtLearning/README.md b/examples/GANwithSelf-taughtLearning/README.md new file mode 100644 index 00000000..d3664859 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/README.md @@ -0,0 +1,71 @@ +# Integrating GAN and Self-taught Learning into Ianvs Lifelong Learning + +## Overview + +We propose an approach that combines [GAN](https://en.wikipedia.org/wiki/Generative_adversarial_network) and [Self-taught Learning](https://ai.stanford.edu/~hllee/icml07-selftaughtlearning.pdf) to solve the small-sample problem in the early stage of ianvs lifelong learning, as shown in the figure below. + +For a quick start, jump directly to [Developer Notes](#developer-notes). + +![](./readme_img/ianvs-lifelonglearning.png) + +We describe the architecture and the process below; more details can be found in [Architecture](#architecture). + +![](./readme_img/ianvs-lifelonglearning2.png) + +1. Train the GAN with the original small-sample data +2. The GAN generates more data according to the learned probability distribution +3. Train the autoencoder (which consists of an encoder and a decoder) with the data generated by the GAN +4. Use the encoder to get feature representations of the original small-sample data +5. Use the feature representations and the original labels to train the model that the user needs +6. Output a well-trained model + +## Architecture + +- Overview + + ![](./readme_img/overview.png) + +- GAN (We refer to [Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis](https://openreview.net/forum?id=1Fqg133qRaI).) + + Discriminator + + ![](./readme_img/discriminator.png) + + Generator + + ![](./readme_img/generator.png) + +- Convolutional AutoEncoder of Self-taught Learning + + ![](./readme_img/cae.png) + +## Developer Notes + +```bash +GANwithSelf-taughtLearning # root path of the project + - config.yaml # the config of input paths (dataset and model to be trained) as well as the hyperparameters of the GAN and Self-taught Learning + - GAN # GAN module + - models.py # define discriminator and generator + - train.py # train the GAN here + - ./train_results # training outputs, such as the GAN model, training loss and evaluation of the GAN + - self-taught-learning # self-taught learning module + - models.py # Define AutoEncoder. Here we use a Convolutional AutoEncoder (CAE). + - train.py # train the autoencoder here + - ./train_results # training outputs, such as the encoder model and training loss + - model-to-be-trained # model to be trained module + - train.py # train the model + - model.py # define the model + - ./train_results # training results of the model to be trained + - util.py # util module +``` + +To use it, configure `config.yaml` so that `GANwithSelf-taughtLearning` knows where the **dataset** is, which **model** you want to train, and the **hyperparameters**. + +A common workflow is shown below: + +1. run `./GAN/train.py` +2. run `./self-taught-learning/train.py` +3.
run `./train.py` + + + diff --git a/examples/GANwithSelf-taughtLearning/config.yaml b/examples/GANwithSelf-taughtLearning/config.yaml new file mode 100644 index 00000000..8a7fcbf5 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/config.yaml @@ -0,0 +1,26 @@ +dataset_path: '../data/img' + +# GAN config +GAN: + - iter: 50000 + - batch_size: 8 + - im_size: 1024 + - name: 'test2' + +# Self-taught Learning config +STL: + - iter: 100 + - lr: 1.0e-3 + - batch_size: 30 + - name: 'encoder_models' + +# model to be trained config +# here we take deeplabv3 as example +deeplabv3: + - iter: 1000 + - batch_size: 3 + - lr: 1.0e-4 + - name: "1" + - cityscapes_data_path: "/home/nailtu/data/cityscapes" + - cityscapes_meta_path: '/home/nailtu/data/cityscapes/meta' + - class_weights: '/home/nailtu/data/cityscapes/meta/class_weights.pkl' \ No newline at end of file diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/__init__.py b/examples/GANwithSelf-taughtLearning/deeplabv3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/datasets.py b/examples/GANwithSelf-taughtLearning/deeplabv3/datasets.py new file mode 100644 index 00000000..196eed16 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/datasets.py @@ -0,0 +1,146 @@ +import torch +import torch.utils.data + +import numpy as np +import cv2 +import os + +train_dirs = ["jena/", "zurich/", "weimar/", "ulm/", "tubingen/", "stuttgart/", + "strasbourg/", "monchengladbach/", "krefeld/", "hanover/", + "hamburg/", "erfurt/", "dusseldorf/", "darmstadt/", "cologne/", + "bremen/", "bochum/", "aachen/"] +val_dirs = ["frankfurt/", "munster/", "lindau/"] +test_dirs = ["berlin", "bielefeld", "bonn", "leverkusen", "mainz", "munich"] + +class DatasetTrain(torch.utils.data.Dataset): + def __init__(self, cityscapes_data_path, cityscapes_meta_path): + self.img_dir = cityscapes_data_path + "/leftImg8bit/train/" + self.label_dir = cityscapes_meta_path + "/label_imgs/" + + self.img_h = 1024 + self.img_w = 2048 + + self.new_img_h = 512 + self.new_img_w = 1024 + + self.examples = [] + for train_dir in train_dirs: + train_img_dir_path = self.img_dir + train_dir + + file_names = os.listdir(train_img_dir_path) + for file_name in file_names: + img_id = file_name.split("_leftImg8bit.png")[0] + + img_path = train_img_dir_path + file_name + + label_img_path = self.label_dir + img_id + ".png" + + example = {} + example["img_path"] = img_path + example["label_img_path"] = label_img_path + example["img_id"] = img_id + self.examples.append(example) + self.examples = self.examples[0:99] + self.num_examples = len(self.examples) + + def __getitem__(self, index): + example = self.examples[index] + + img_path = example["img_path"] + img = cv2.imread(img_path, -1) + # img = cv2.resize(img, (self.new_img_w, self.new_img_h), + # interpolation=cv2.INTER_NEAREST) + label_img_path = example["label_img_path"] + label_img = cv2.imread(label_img_path, -1) + label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h), + interpolation=cv2.INTER_NEAREST) + # flip = np.random.randint(low=0, high=2) + # if flip == 1: + # img = cv2.flip(img, 1) + # label_img = cv2.flip(label_img, 1) + # scale = np.random.uniform(low=0.7, high=2.0) + # new_img_h = int(scale*self.new_img_h) + # new_img_w = int(scale*self.new_img_w) + # img = cv2.resize(img, (new_img_w, new_img_h), + # interpolation=cv2.INTER_NEAREST) + # label_img = cv2.resize(label_img, (new_img_w, new_img_h), + # interpolation=cv2.INTER_NEAREST) + # start_x = 
np.random.randint(low=0, high=(new_img_w - 256)) + # end_x = start_x + 256 + # start_y = np.random.randint(low=0, high=(new_img_h - 256)) + # end_y = start_y + 256 + # + # img = img[start_y:end_y, start_x:end_x] + # label_img = label_img[start_y:end_y, start_x:end_x] + img = img/255.0 + img = img - np.array([0.485, 0.456, 0.406]) + img = img/np.array([0.229, 0.224, 0.225]) + img = np.transpose(img, (2, 0, 1)) + img = img.astype(np.float32) + img = torch.from_numpy(img) + label_img = torch.from_numpy(label_img) + + return (img, label_img) + + def __len__(self): + return self.num_examples + +class DatasetVal(torch.utils.data.Dataset): + def __init__(self, cityscapes_data_path, cityscapes_meta_path): + self.img_dir = cityscapes_data_path + "/leftImg8bit/val/" + self.label_dir = cityscapes_meta_path + "/label_imgs/" + + self.img_h = 1024 + self.img_w = 2048 + + self.new_img_h = 1024 + self.new_img_w = 2048 + + self.examples = [] + for val_dir in val_dirs: + val_img_dir_path = self.img_dir + val_dir + + file_names = os.listdir(val_img_dir_path) + for file_name in file_names: + img_id = file_name.split("_leftImg8bit.png")[0] + + img_path = val_img_dir_path + file_name + + label_img_path = self.label_dir + img_id + ".png" + label_img = cv2.imread(label_img_path, -1) + + example = {} + example["img_path"] = img_path + example["label_img_path"] = label_img_path + example["img_id"] = img_id + self.examples.append(example) + self.examples = self.examples[0:99] + self.num_examples = len(self.examples) + + def __getitem__(self, index): + example = self.examples[index] + + img_id = example["img_id"] + + img_path = example["img_path"] + img = cv2.imread(img_path, -1) + img = cv2.resize(img, (self.new_img_w, self.new_img_h), + interpolation=cv2.INTER_NEAREST) + + label_img_path = example["label_img_path"] + label_img = cv2.imread(label_img_path, -1) + label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h), + interpolation=cv2.INTER_NEAREST) + img = img/255.0 + img = img - np.array([0.485, 0.456, 0.406]) + img = img/np.array([0.229, 0.224, 0.225]) + img = np.transpose(img, (2, 0, 1)) + img = img.astype(np.float32) + + img = torch.from_numpy(img) + label_img = torch.from_numpy(label_img) + + return (img, label_img, img_id) + + def __len__(self): + return self.num_examples \ No newline at end of file diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/model/aspp.py b/examples/GANwithSelf-taughtLearning/deeplabv3/model/aspp.py new file mode 100644 index 00000000..ccb49b88 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/model/aspp.py @@ -0,0 +1,99 @@ +# camera-ready + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ASPP(nn.Module): + def __init__(self, num_classes): + super(ASPP, self).__init__() + + self.conv_1x1_1 = nn.Conv2d(512, 256, kernel_size=1) + self.bn_conv_1x1_1 = nn.BatchNorm2d(256) + + self.conv_3x3_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=6, dilation=6) + self.bn_conv_3x3_1 = nn.BatchNorm2d(256) + + self.conv_3x3_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12) + self.bn_conv_3x3_2 = nn.BatchNorm2d(256) + + self.conv_3x3_3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=18, dilation=18) + self.bn_conv_3x3_3 = nn.BatchNorm2d(256) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.conv_1x1_2 = nn.Conv2d(512, 256, kernel_size=1) + self.bn_conv_1x1_2 = nn.BatchNorm2d(256) + + self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) + self.bn_conv_1x1_3 = 
nn.BatchNorm2d(256) + + self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) + + def forward(self, feature_map): + # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8)) + + feature_map_h = feature_map.size()[2] # (== h/16) + feature_map_w = feature_map.size()[3] # (== w/16) + + out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + + out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1)) + out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) + out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16)) + + out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16)) + out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16)) + out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16)) + + return out + +class ASPP_Bottleneck(nn.Module): + def __init__(self, num_classes): + super(ASPP_Bottleneck, self).__init__() + + self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1) + self.bn_conv_1x1_1 = nn.BatchNorm2d(256) + + self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6) + self.bn_conv_3x3_1 = nn.BatchNorm2d(256) + + self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12) + self.bn_conv_3x3_2 = nn.BatchNorm2d(256) + + self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18) + self.bn_conv_3x3_3 = nn.BatchNorm2d(256) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1) + self.bn_conv_1x1_2 = nn.BatchNorm2d(256) + + self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) + self.bn_conv_1x1_3 = nn.BatchNorm2d(256) + + self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) + + def forward(self, feature_map): + # (feature_map has shape (batch_size, 4*512, h/16, w/16)) + + feature_map_h = feature_map.size()[2] # (== h/16) + feature_map_w = feature_map.size()[3] # (== w/16) + + out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) + + out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1)) + out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) + out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16)) + + out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16)) + out = 
F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16)) + out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16)) + + return out diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/model/deeplabv3.py b/examples/GANwithSelf-taughtLearning/deeplabv3/model/deeplabv3.py new file mode 100644 index 00000000..a97107e9 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/model/deeplabv3.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import os +import sys +sys.path.append('model') +from resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, ResNet152_OS16, ResNet18_OS8, ResNet34_OS8 +from aspp import ASPP, ASPP_Bottleneck + +class DeepLabV3(nn.Module): + def __init__(self, model_id, project_dir): + super(DeepLabV3, self).__init__() + + self.num_classes = 20 + + self.model_id = model_id + self.project_dir = project_dir + self.create_model_dirs() + + self.resnet = ResNet18_OS8() # NOTE! specify the type of ResNet here + self.aspp = ASPP(num_classes=self.num_classes) # NOTE! if you use ResNet50-152, set self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) instead + + def forward(self, x): + # (x has shape (batch_size, 3, h, w)) + + h = x.size()[2] + w = x.size()[3] + + feature_map = self.resnet(x) # (shape: (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8). If self.resnet is ResNet50-152, it will be (batch_size, 4*512, h/16, w/16)) + # print(feature_map.shape) + output = self.aspp(feature_map) # (shape: (batch_size, num_classes, h/16, w/16)) + + output = F.upsample(output, size=(h, w), mode="bilinear") # (shape: (batch_size, num_classes, h, w)) + + return output + + def create_model_dirs(self): + self.logs_dir = self.project_dir + "/training_logs" + self.model_dir = self.logs_dir + "/model_%s" % self.model_id + self.checkpoints_dir = self.model_dir + "/checkpoints" + if not os.path.exists(self.logs_dir): + os.makedirs(self.logs_dir) + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + os.makedirs(self.checkpoints_dir) diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/model/resnet.py b/examples/GANwithSelf-taughtLearning/deeplabv3/model/resnet.py new file mode 100644 index 00000000..80d2dacd --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/model/resnet.py @@ -0,0 +1,233 @@ +# camera-ready + +# NOTE! 
OS: output stride, the ratio of input image resolution to final output resolution (OS16: output size is (img_h/16, img_w/16)) (OS8: output size is (img_h/8, img_w/8)) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1): + strides = [stride] + [1]*(num_blocks - 1) # (stride == 2, num_blocks == 4 --> strides == [2, 1, 1, 1]) + + blocks = [] + for stride in strides: + blocks.append(block(in_channels=in_channels, channels=channels, stride=stride, dilation=dilation)) + in_channels = block.expansion*channels + + layer = nn.Sequential(*blocks) # (*blocks: call with unpacked list entries as arguments) + + return layer + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels, channels, stride=1, dilation=1): + super(BasicBlock, self).__init__() + + out_channels = self.expansion*channels + + self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) + self.bn1 = nn.BatchNorm2d(channels) + + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False) + self.bn2 = nn.BatchNorm2d(channels) + + if (stride != 1) or (in_channels != out_channels): + conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + bn = nn.BatchNorm2d(out_channels) + self.downsample = nn.Sequential(conv, bn) + else: + self.downsample = nn.Sequential() + + def forward(self, x): + # (x has shape: (batch_size, in_channels, h, w)) + + out = F.relu(self.bn1(self.conv1(x))) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2) + out = self.bn2(self.conv2(out)) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2) + + out = out + self.downsample(x) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2) + + out = F.relu(out) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2) + + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_channels, channels, stride=1, dilation=1): + super(Bottleneck, self).__init__() + + out_channels = self.expansion*channels + + self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(channels) + + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) + self.bn2 = nn.BatchNorm2d(channels) + + self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(out_channels) + + if (stride != 1) or (in_channels != out_channels): + conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + bn = nn.BatchNorm2d(out_channels) + self.downsample = nn.Sequential(conv, bn) + else: + self.downsample = nn.Sequential() + + def forward(self, x): + # (x has shape: (batch_size, in_channels, h, w)) + + out = F.relu(self.bn1(self.conv1(x))) # (shape: (batch_size, channels, h, w)) + out = F.relu(self.bn2(self.conv2(out))) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2) + out = self.bn3(self.conv3(out)) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2) + + out = out + self.downsample(x) # (shape: 
(batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2) + + out = F.relu(out) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2) + + return out + +class ResNet_Bottleneck_OS16(nn.Module): + def __init__(self, num_layers): + super(ResNet_Bottleneck_OS16, self).__init__() + + if num_layers == 50: + resnet = models.resnet50() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet50-19c8e357.pth")) + # remove fully connected layer, avg pool and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-3]) + + print ("pretrained resnet, 50") + elif num_layers == 101: + resnet = models.resnet101() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet101-5d3b4d8f.pth")) + # remove fully connected layer, avg pool and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-3]) + + print ("pretrained resnet, 101") + elif num_layers == 152: + resnet = models.resnet152() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet152-b121ed2d.pth")) + # remove fully connected layer, avg pool and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-3]) + + print ("pretrained resnet, 152") + else: + raise Exception("num_layers must be in {50, 101, 152}!") + + self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2) + + def forward(self, x): + # (x has shape (batch_size, 3, h, w)) + + # pass x through (parts of) the pretrained ResNet: + c4 = self.resnet(x) # (shape: (batch_size, 4*256, h/16, w/16)) (it's called c4 since 16 == 2^4) + + output = self.layer5(c4) # (shape: (batch_size, 4*512, h/16, w/16)) + + return output + +class ResNet_BasicBlock_OS16(nn.Module): + def __init__(self, num_layers): + super(ResNet_BasicBlock_OS16, self).__init__() + + if num_layers == 18: + resnet = models.resnet18() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet18-5c106cde.pth")) + # remove fully connected layer, avg pool and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-3]) + + num_blocks = 2 + print ("pretrained resnet, 18") + elif num_layers == 34: + resnet = models.resnet34() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet34-333f7ec4.pth")) + # remove fully connected layer, avg pool and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-3]) + + num_blocks = 3 + print ("pretrained resnet, 34") + else: + raise Exception("num_layers must be in {18, 34}!") + + self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks, stride=1, dilation=2) + + def forward(self, x): + # (x has shape (batch_size, 3, h, w)) + + # pass x through (parts of) the pretrained ResNet: + c4 = self.resnet(x) # (shape: (batch_size, 256, h/16, w/16)) (it's called c4 since 16 == 2^4) + + output = self.layer5(c4) # (shape: (batch_size, 512, h/16, w/16)) + + return output + +class ResNet_BasicBlock_OS8(nn.Module): + def __init__(self, num_layers): + super(ResNet_BasicBlock_OS8, self).__init__() + + if num_layers == 18: + resnet = models.resnet18() + # load pretrained model: + resnet.load_state_dict(torch.load("/home/nailtu/PycharmProjects/deeplabv3-master/pretrained_models/resnet/resnet18-5c106cde.pth")) + # remove fully 
connected layer, avg pool, layer4 and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-4]) + + num_blocks_layer_4 = 2 + num_blocks_layer_5 = 2 + print ("pretrained resnet, 18") + elif num_layers == 34: + resnet = models.resnet34() + # load pretrained model: + resnet.load_state_dict(torch.load("/root/deeplabv3/pretrained_models/resnet/resnet34-333f7ec4.pth")) + # remove fully connected layer, avg pool, layer4 and layer5: + self.resnet = nn.Sequential(*list(resnet.children())[:-4]) + + num_blocks_layer_4 = 6 + num_blocks_layer_5 = 3 + print ("pretrained resnet, 34") + else: + raise Exception("num_layers must be in {18, 34}!") + + self.layer4 = make_layer(BasicBlock, in_channels=128, channels=256, num_blocks=num_blocks_layer_4, stride=1, dilation=2) + + self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks_layer_5, stride=1, dilation=4) + + def forward(self, x): + # (x has shape (batch_size, 3, h, w)) + + # pass x through (parts of) the pretrained ResNet: + c3 = self.resnet(x) # (shape: (batch_size, 128, h/8, w/8)) (it's called c3 since 8 == 2^3) + + output = self.layer4(c3) # (shape: (batch_size, 256, h/8, w/8)) + output = self.layer5(output) # (shape: (batch_size, 512, h/8, w/8)) + + return output + +def ResNet18_OS16(): + return ResNet_BasicBlock_OS16(num_layers=18) + +def ResNet34_OS16(): + return ResNet_BasicBlock_OS16(num_layers=34) + +def ResNet50_OS16(): + return ResNet_Bottleneck_OS16(num_layers=50) + +def ResNet101_OS16(): + return ResNet_Bottleneck_OS16(num_layers=101) + +def ResNet152_OS16(): + return ResNet_Bottleneck_OS16(num_layers=152) + +def ResNet18_OS8(): + return ResNet_BasicBlock_OS8(num_layers=18) + +def ResNet34_OS8(): + return ResNet_BasicBlock_OS8(num_layers=34) diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/train.py b/examples/GANwithSelf-taughtLearning/deeplabv3/train.py new file mode 100644 index 00000000..c7ac21ac --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/train.py @@ -0,0 +1,150 @@ +import os + +from datasets import DatasetTrain, DatasetVal +from model.deeplabv3 import DeepLabV3 +import sys +from utils.utils import add_weight_decay + +import torch +import torch.utils.data +import torch.nn as nn +from torch.autograd import Variable +import torch.optim as optim +import torch.nn.functional as F + +import numpy as np +import pickle +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +sys.path.append('..') # (so that util.py at the example root can be imported) +from util import load_yaml + +# NOTE! this must be the same Encoder as in ../self-taught-learning/models.py: the attribute names enc1/enc2 have to match, otherwise load_state_dict() below fails with missing/unexpected keys +class Encoder(nn.Module): + def __init__(self) -> None: + super(Encoder, self).__init__() + self.enc1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1 + ) + self.enc2 = nn.Conv2d( + in_channels=8, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.relu(self.enc1(x)) + x = F.relu(self.enc2(x)) + return x + + +if __name__ == '__main__': + configs = load_yaml('../config.yaml') + model_id = configs['deeplabv3'][3]['name'] + + encoder = Encoder().cuda() + encoder.load_state_dict(torch.load('../self-taught-learning/train_results/encoder_models4/encoder50.pth')) + + num_epochs = configs['deeplabv3'][0]['iter'] + batch_size = configs['deeplabv3'][1]['batch_size'] + learning_rate = configs['deeplabv3'][2]['lr'] + + network = DeepLabV3(model_id, project_dir=os.getcwd()).cuda() + + train_dataset = DatasetTrain(cityscapes_data_path=configs['deeplabv3'][4]['cityscapes_data_path'], + cityscapes_meta_path=configs['deeplabv3'][5]['cityscapes_meta_path']) + val_dataset = 
DatasetVal(cityscapes_data_path=configs['deeplabv3'][4]['cityscapes_data_path'], + cityscapes_meta_path=configs['deeplabv3'][5]['cityscapes_meta_path']) + + num_train_batches = int(len(train_dataset)/batch_size) + num_val_batches = int(len(val_dataset)/batch_size) + print ("num_train_batches:", num_train_batches) + print ("num_val_batches:", num_val_batches) + + train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, shuffle=True, + num_workers=1) + val_loader = torch.utils.data.DataLoader(dataset=val_dataset, + batch_size=batch_size, shuffle=False, + num_workers=1) + + params = add_weight_decay(network, l2_value=0.0001) + optimizer = torch.optim.Adam(params, lr=learning_rate) + + with open(configs['deeplabv3'][6]['class_weights'], "rb") as file: + class_weights = np.array(pickle.load(file)) + class_weights = torch.from_numpy(class_weights) + class_weights = Variable(class_weights.type(torch.FloatTensor)).cuda() + + loss_fn = nn.CrossEntropyLoss(weight=class_weights) + + epoch_losses_train = [] + epoch_losses_val = [] + for epoch in range(num_epochs): + print ("###########################") + print ("######## NEW EPOCH ########") + print ("###########################") + print ("epoch: %d/%d" % (epoch+1, num_epochs)) + + ############################################################################ + # train: + ############################################################################ + network.train() # (set in training mode, this affects BatchNorm and dropout) + batch_losses = [] + for step, (imgs, label_imgs) in enumerate(train_loader): + imgs = Variable(imgs).cuda() + imgs = encoder(imgs) + label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() + + outputs = network(imgs) + + loss = loss_fn(outputs, label_imgs) + loss_value = loss.data.cpu().numpy() + batch_losses.append(loss_value) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + epoch_loss = np.mean(batch_losses) + epoch_losses_train.append(epoch_loss) + with open("%s/epoch_losses_train.pkl" % network.model_dir, "wb") as file: + pickle.dump(epoch_losses_train, file) + print ("train loss: %g" % epoch_loss) + plt.figure(1) + plt.plot(epoch_losses_train, "k^") + plt.plot(epoch_losses_train, "k") + plt.ylabel("loss") + plt.xlabel("epoch") + plt.title("train loss per epoch") + plt.savefig("%s/epoch_losses_train.png" % network.model_dir) + plt.close(1) + + print ("####") + + network.eval() # (set in evaluation mode, this affects BatchNorm and dropout) + batch_losses = [] + for step, (imgs, label_imgs, img_ids) in enumerate(val_loader): + with torch.no_grad(): + imgs = Variable(imgs).cuda() + imgs = encoder(imgs) # (pass val images through the same frozen self-taught encoder as in training, so train and val inputs match) + label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() + + outputs = network(imgs) + loss = loss_fn(outputs, label_imgs) + loss_value = loss.data.cpu().numpy() + batch_losses.append(loss_value) + + epoch_loss = np.mean(batch_losses) + epoch_losses_val.append(epoch_loss) + with open("%s/epoch_losses_val.pkl" % network.model_dir, "wb") as file: + pickle.dump(epoch_losses_val, file) + print ("val loss: %g" % epoch_loss) + plt.figure(1) + plt.plot(epoch_losses_val, "k^") + plt.plot(epoch_losses_val, "k") + plt.ylabel("loss") + plt.xlabel("epoch") + plt.title("val loss per epoch") + plt.savefig("%s/epoch_losses_val.png" % network.model_dir) + plt.close(1) diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/utils/__init__.py b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/utils/preprocess_data.py 
b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/preprocess_data.py new file mode 100644 index 00000000..ca9dc83b --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/preprocess_data.py @@ -0,0 +1,190 @@ +# camera-ready + +import pickle +import numpy as np +import cv2 +import os +from collections import namedtuple + +# (NOTE! this is taken from the official Cityscapes scripts:) +Label = namedtuple( 'Label' , [ + + 'name' , # The identifier of this label, e.g. 'car', 'person', ... . + # We use them to uniquely name a class + + 'id' , # An integer ID that is associated with this label. + # The IDs are used to represent the label in ground truth images + # An ID of -1 means that this label does not have an ID and thus + # is ignored when creating ground truth images (e.g. license plate). + # Do not modify these IDs, since exactly these IDs are expected by the + # evaluation server. + + 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create + # ground truth images with train IDs, using the tools provided in the + # 'preparation' folder. However, make sure to validate or submit results + # to our evaluation server using the regular IDs above! + # For trainIds, multiple labels might have the same ID. Then, these labels + # are mapped to the same class in the ground truth images. For the inverse + # mapping, we use the label that is defined first in the list below. + # For example, mapping all void-type classes to the same ID in training, + # might make sense for some approaches. + # Max value is 255! + + 'category' , # The name of the category that this label belongs to + + 'categoryId' , # The ID of this category. Used to create ground truth images + # on category level. + + 'hasInstances', # Whether this label distinguishes between single instances or not + + 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored + # during evaluations or not + + 'color' , # The color of this label + ] ) + +# (NOTE! 
this is taken from the official Cityscapes scripts:) +labels = [ + # name id trainId category catId hasInstances ignoreInEval color + Label( 'unlabeled' , 0 , 19 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'ego vehicle' , 1 , 19 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'rectification border' , 2 , 19 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'out of roi' , 3 , 19 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'static' , 4 , 19 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'dynamic' , 5 , 19 , 'void' , 0 , False , True , (111, 74, 0) ), + Label( 'ground' , 6 , 19 , 'void' , 0 , False , True , ( 81, 0, 81) ), + Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), + Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), + Label( 'parking' , 9 , 19 , 'flat' , 1 , False , True , (250,170,160) ), + Label( 'rail track' , 10 , 19 , 'flat' , 1 , False , True , (230,150,140) ), + Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), + Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), + Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), + Label( 'guard rail' , 14 , 19 , 'construction' , 2 , False , True , (180,165,180) ), + Label( 'bridge' , 15 , 19 , 'construction' , 2 , False , True , (150,100,100) ), + Label( 'tunnel' , 16 , 19 , 'construction' , 2 , False , True , (150,120, 90) ), + Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), + Label( 'polegroup' , 18 , 19 , 'object' , 3 , False , True , (153,153,153) ), + Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), + Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), + Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), + Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), + Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), + Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), + Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), + Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), + Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), + Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), + Label( 'caravan' , 29 , 19 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), + Label( 'trailer' , 30 , 19 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), + Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), + Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), + Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), + Label( 'license plate' , -1 , 19 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), +] + +# create a function which maps id to trainId: +id_to_trainId = {label.id: label.trainId for label in labels} +id_to_trainId_map_func = np.vectorize(id_to_trainId.get) + +train_dirs = ["jena/", "zurich/", "weimar/", "ulm/", "tubingen/", "stuttgart/", + "strasbourg/", "monchengladbach/", "krefeld/", "hanover/", + "hamburg/", "erfurt/", "dusseldorf/", "darmstadt/", "cologne/", + "bremen/", "bochum/", "aachen/"] +val_dirs = ["frankfurt/", "munster/", "lindau/"] +test_dirs = ["berlin", "bielefeld", "bonn", "leverkusen", "mainz", "munich"] + +cityscapes_data_path = "/home/nailtu/data/cityscapes" +cityscapes_meta_path = "/home/nailtu/data/cityscapes/meta" + +if not 
os.path.exists(cityscapes_meta_path): + os.makedirs(cityscapes_meta_path) +if not os.path.exists(cityscapes_meta_path + "/label_imgs"): + os.makedirs(cityscapes_meta_path + "/label_imgs") + +################################################################################ +# convert all labels to label imgs with trainId pixel values (and save to disk): +################################################################################ +train_label_img_paths = [] + +img_dir = cityscapes_data_path + "/leftImg8bit/train/" +label_dir = cityscapes_data_path + "/gtFine/train/" +for train_dir in train_dirs: + print (train_dir) + + train_img_dir_path = img_dir + train_dir + train_label_dir_path = label_dir + train_dir + + file_names = os.listdir(train_img_dir_path) + for file_name in file_names: + img_id = file_name.split("_leftImg8bit.png")[0] + + gtFine_img_path = train_label_dir_path + img_id + "_gtFine_labelIds.png" + gtFine_img = cv2.imread(gtFine_img_path, -1) # (shape: (1024, 2048)) + + # convert gtFine_img from id to trainId pixel values: + label_img = id_to_trainId_map_func(gtFine_img) # (shape: (1024, 2048)) + label_img = label_img.astype(np.uint8) + + cv2.imwrite(cityscapes_meta_path + "/label_imgs/" + img_id + ".png", label_img) + train_label_img_paths.append(cityscapes_meta_path + "/label_imgs/" + img_id + ".png") + +img_dir = cityscapes_data_path + "/leftImg8bit/val/" +label_dir = cityscapes_data_path + "/gtFine/val/" +for val_dir in val_dirs: + print (val_dir) + + val_img_dir_path = img_dir + val_dir + val_label_dir_path = label_dir + val_dir + + file_names = os.listdir(val_img_dir_path) + for file_name in file_names: + img_id = file_name.split("_leftImg8bit.png")[0] + + gtFine_img_path = val_label_dir_path + img_id + "_gtFine_labelIds.png" + gtFine_img = cv2.imread(gtFine_img_path, -1) # (shape: (1024, 2048)) + + # convert gtFine_img from id to trainId pixel values: + label_img = id_to_trainId_map_func(gtFine_img) # (shape: (1024, 2048)) + label_img = label_img.astype(np.uint8) + + cv2.imwrite(cityscapes_meta_path + "/label_imgs/" + img_id + ".png", label_img) + +################################################################################ +# compute the class weights: +################################################################################ +print ("computing class weights") + +num_classes = 20 + +trainId_to_count = {} +for trainId in range(num_classes): + trainId_to_count[trainId] = 0 + +# get the total number of pixels in all train label_imgs that are of each object class: +for step, label_img_path in enumerate(train_label_img_paths): + if step % 100 == 0: + print (step) + + label_img = cv2.imread(label_img_path, -1) + + for trainId in range(num_classes): + # count how many pixels in label_img are of object class trainId: + trainId_mask = np.equal(label_img, trainId) + trainId_count = np.sum(trainId_mask) + + # add to the total count: + trainId_to_count[trainId] += trainId_count + +# compute the class weights according to the ENet paper: +class_weights = [] +total_count = sum(trainId_to_count.values()) +for trainId, count in trainId_to_count.items(): + trainId_prob = float(count)/float(total_count) + trainId_weight = 1/np.log(1.02 + trainId_prob) + class_weights.append(trainId_weight) + +print (class_weights) + +with open(cityscapes_meta_path + "/class_weights.pkl", "wb") as file: + pickle.dump(class_weights, file, protocol=2) # (protocol=2 is needed to be able to open this file with python2) diff --git 
a/examples/GANwithSelf-taughtLearning/deeplabv3/utils/random_code.py b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/random_code.py new file mode 100644 index 00000000..4c19463f --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/random_code.py @@ -0,0 +1,23 @@ +# camera-ready + +# this file contains code snippets which I have found (more or less) useful at +# some point during the project. Probably nothing interesting to see here. + +import pickle +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np + +model_id = "13_2_2_2" + +with open("/home/fregu856/exjobb/training_logs/multitask/model_" + model_id + "/epoch_losses_train.pkl", "rb") as file: + train_loss = pickle.load(file) + +with open("/home/fregu856/exjobb/training_logs/multitask/model_" + model_id + "/epoch_losses_val.pkl", "rb") as file: + val_loss = pickle.load(file) + +print ("train loss min:", np.argmin(np.array(train_loss)), np.min(np.array(train_loss))) + +print ("val loss min:", np.argmin(np.array(val_loss)), np.min(np.array(val_loss))) diff --git a/examples/GANwithSelf-taughtLearning/deeplabv3/utils/utils.py b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/utils.py new file mode 100644 index 00000000..b1844c10 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/deeplabv3/utils/utils.py @@ -0,0 +1,56 @@ +# camera-ready + +import torch +import torch.nn as nn + +import numpy as np + +def add_weight_decay(net, l2_value, skip_list=()): + # https://raberrytv.wordpress.com/2017/10/29/pytorch-weight-decay-made-easy/ + + decay, no_decay = [], [] + for name, param in net.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + + return [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': l2_value}] + +# function for colorizing a label image: +def label_img_to_color(img): + label_to_color = { + 0: [128, 64,128], + 1: [244, 35,232], + 2: [ 70, 70, 70], + 3: [102,102,156], + 4: [190,153,153], + 5: [153,153,153], + 6: [250,170, 30], + 7: [220,220, 0], + 8: [107,142, 35], + 9: [152,251,152], + 10: [ 70,130,180], + 11: [220, 20, 60], + 12: [255, 0, 0], + 13: [ 0, 0,142], + 14: [ 0, 0, 70], + 15: [ 0, 60,100], + 16: [ 0, 80,100], + 17: [ 0, 0,230], + 18: [119, 11, 32], + 19: [81, 0, 81] + } + + img_height, img_width = img.shape + + img_color = np.zeros((img_height, img_width, 3)) + for row in range(img_height): + for col in range(img_width): + label = img[row, col] + + img_color[row, col] = np.array(label_to_color[label]) + + return img_color diff --git a/examples/GANwithSelf-taughtLearning/imgs/cae.png b/examples/GANwithSelf-taughtLearning/imgs/cae.png new file mode 100644 index 00000000..57dcb21a Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/cae.png differ diff --git a/examples/GANwithSelf-taughtLearning/imgs/discriminator.png b/examples/GANwithSelf-taughtLearning/imgs/discriminator.png new file mode 100644 index 00000000..d8bfa8df Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/discriminator.png differ diff --git a/examples/GANwithSelf-taughtLearning/imgs/generator.png b/examples/GANwithSelf-taughtLearning/imgs/generator.png new file mode 100644 index 00000000..dd07fd94 Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/generator.png differ diff --git a/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning.png 
b/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning.png new file mode 100644 index 00000000..5a5497fa Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning.png differ diff --git a/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning2.png b/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning2.png new file mode 100644 index 00000000..b409c4dc Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/ianvs-lifelonglearning2.png differ diff --git a/examples/GANwithSelf-taughtLearning/imgs/overview.png b/examples/GANwithSelf-taughtLearning/imgs/overview.png new file mode 100644 index 00000000..4596bbe5 Binary files /dev/null and b/examples/GANwithSelf-taughtLearning/imgs/overview.png differ diff --git a/examples/GANwithSelf-taughtLearning/self-taught-learning/__init__.py b/examples/GANwithSelf-taughtLearning/self-taught-learning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/GANwithSelf-taughtLearning/self-taught-learning/models.py b/examples/GANwithSelf-taughtLearning/self-taught-learning/models.py new file mode 100644 index 00000000..b1eb6fea --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/self-taught-learning/models.py @@ -0,0 +1,47 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + def __init__(self) -> None: + super(Encoder, self).__init__() + self.enc1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1 + ) + self.enc2 = nn.Conv2d( + in_channels=8, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.relu(self.enc1(x)) + x = F.relu(self.enc2(x)) + return x + + +class Decoder(nn.Module): + def __init__(self) -> None: + super(Decoder, self).__init__() + self.dec1 = nn.ConvTranspose2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1, output_padding=0 + ) + self.dec2 = nn.ConvTranspose2d( + in_channels=8, out_channels=3, kernel_size=3, stride=2, padding=1, output_padding=1 + ) + + def forward(self, x): + x = F.relu(self.dec1(x)) + x = F.relu(self.dec2(x)) + return x + + +class Autoencoder(nn.Module): + def __init__(self) -> None: + super(Autoencoder, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder() + + def forward(self, x): + x = self.encoder(x) + # print(x.shape) + x = self.decoder(x) + return x diff --git a/examples/GANwithSelf-taughtLearning/self-taught-learning/train.py b/examples/GANwithSelf-taughtLearning/self-taught-learning/train.py new file mode 100644 index 00000000..4ea32483 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/self-taught-learning/train.py @@ -0,0 +1,105 @@ +from torchvision.utils import save_image +from models import Autoencoder +import torch +import cv2 +import numpy as np +from torch.utils.data import DataLoader +import os +import torch.nn as nn +import torch.optim as optim +import csv +import time +import sys +sys.path.append('..') # (so that util.py at the example root can be imported) +from util import load_yaml + + +class DatasetAutoEncoder(torch.utils.data.Dataset): + def __init__(self, fake_images_path): + self.img_dir = fake_images_path + + self.new_img_w = 2048 + + self.new_img_h = 1024 + + self.examples = [] + + file_names = os.listdir(fake_images_path) + + for file_name in file_names: + img_path = fake_images_path + file_name + self.examples.append({'img_path': img_path}) + + self.examples = self.examples[0:60] + + self.num_examples = len(self.examples) + + def __getitem__(self, index): + example = self.examples[index] + + img_path = example["img_path"] + img = 
cv2.imread(img_path, -1) + img = cv2.resize(img, (self.new_img_w, self.new_img_h), + interpolation=cv2.INTER_NEAREST) + img = img / 255.0 + img = img - np.array([0.485, 0.456, 0.406]) + img = img / np.array([0.229, 0.224, 0.225]) + img = np.transpose(img, (2, 0, 1)) + + img = img.astype(np.float32) + + img = torch.from_numpy(img) + return img + + def __len__(self): + return self.num_examples + + +def save_decoded_image(img, name): + img = img.view(1, 3, 1024, 2048) + save_image(img, name) + + +if __name__ == '__main__': + configs = load_yaml('../config.yaml') + device = 'cuda' if torch.cuda.is_available() else 'cpu' + LEARNING_RATE = configs['STL'][1]['lr'] + NUM_EPOCHS = configs['STL'][0]['iter'] + batch_size = configs['STL'][2]['batch_size'] + name = configs['STL'][3]['name'] + save_dir = 'train_results/' + name + if not os.path.exists(save_dir): + os.makedirs(save_dir) # (also creates train_results/ itself if it does not exist yet) + net = Autoencoder().to(device) + encoder_dataset = DatasetAutoEncoder(fake_images_path='../data/fake_imgs/') + encoder_loader = DataLoader(dataset=encoder_dataset, batch_size=batch_size, drop_last=True) + criterion = nn.MSELoss() + optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE) + train_loss = [] + with open('train_loss1.csv', 'w') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['epoch', 'loss']) + print('========start training============') + start_time = time.time() + for epoch in range(1, NUM_EPOCHS + 1): + print('======={}========'.format(epoch)) + running_loss = 0.0 + for batch_idx, img in enumerate(encoder_loader): + img = img.to(device) + optimizer.zero_grad() + outputs = net(img) + loss = criterion(outputs, img) + loss.backward() + optimizer.step() + running_loss += loss.item() + loss = running_loss / len(encoder_loader) + train_loss.append(loss) + print('Epoch {} of {}, Train Loss: {}'.format( + epoch, NUM_EPOCHS, loss)) + with open('train_loss1.csv', 'a') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([epoch, loss]) + torch.save(net.encoder.state_dict(), save_dir + '/encoder{}.pth'.format(epoch)) + save_decoded_image(img[0].cpu().data, name=save_dir + '/original{}.png'.format(epoch)) + save_decoded_image(outputs[0].cpu().data, name=save_dir + '/decoded{}.png'.format(epoch)) + print('elapsed time: {}'.format(time.time() - start_time)) diff --git a/examples/GANwithSelf-taughtLearning/self-taught-learning/train_loss1.csv b/examples/GANwithSelf-taughtLearning/self-taught-learning/train_loss1.csv new file mode 100644 index 00000000..c99f3e62 --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/self-taught-learning/train_loss1.csv @@ -0,0 +1,5 @@ +epoch,loss +1,1.1943154335021973 +2,1.1841652989387512 +3,1.1787938475608826 +4,1.1748800873756409 diff --git a/examples/GANwithSelf-taughtLearning/util.py b/examples/GANwithSelf-taughtLearning/util.py new file mode 100644 index 00000000..005a9cab --- /dev/null +++ b/examples/GANwithSelf-taughtLearning/util.py @@ -0,0 +1,12 @@ +import yaml + + +def load_yaml(path): + with open(path) as f: + data = yaml.load(f, Loader=yaml.FullLoader) + # print(data) + return data + + +if __name__ == '__main__': + load_yaml('config.yaml')
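
Note on configuration: both training scripts above read `../config.yaml` through `util.load_yaml` and index it positionally (for example `configs['deeplabv3'][0]['iter']` and `configs['STL'][2]['batch_size']`), but the YAML file itself is not part of this diff. The sketch below is only a reconstruction consistent with those positional accesses: the list order is dictated by the code, while every value is an illustrative assumption (the Cityscapes paths reuse the ones hard-coded in `preprocess_data.py`).

```yaml
# Hypothetical config.yaml, inferred from the positional accesses in
# deeplabv3/train.py and self-taught-learning/train.py; all values are examples.
deeplabv3:
  - iter: 1000                                               # configs['deeplabv3'][0]['iter']
  - batch_size: 2                                            # configs['deeplabv3'][1]['batch_size']
  - lr: 0.0001                                               # configs['deeplabv3'][2]['lr']
  - name: model1                                             # configs['deeplabv3'][3]['name']
  - cityscapes_data_path: /home/nailtu/data/cityscapes       # configs['deeplabv3'][4]
  - cityscapes_meta_path: /home/nailtu/data/cityscapes/meta  # configs['deeplabv3'][5]
  - class_weights: /home/nailtu/data/cityscapes/meta/class_weights.pkl
STL:
  - iter: 50                                                 # configs['STL'][0]['iter']
  - lr: 0.001                                                # configs['STL'][1]['lr']
  - batch_size: 4                                            # configs['STL'][2]['batch_size']
  - name: encoder_models4                                    # configs['STL'][3]['name']
```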
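The hand-off between the two stages is easy to miss when reading the diff: `self-taught-learning/train.py` fits the `Autoencoder` to GAN-generated images from `../data/fake_imgs/` and saves only the encoder weights (`net.encoder.state_dict()`), and `deeplabv3/train.py` then reloads those weights into an identical `Encoder` and passes every image through it before the segmentation network. A minimal sketch of that hand-off, assuming a checkpoint such as `train_results/encoder_models4/encoder50.pth` has already been produced and that the snippet runs from the `self-taught-learning/` directory so `models.py` is importable:

```python
# Sketch: reload the frozen self-taught encoder and inspect its output shape.
import torch
from models import Encoder

encoder = Encoder()
state = torch.load('train_results/encoder_models4/encoder50.pth', map_location='cpu')
encoder.load_state_dict(state)  # keys enc1.*/enc2.* match the class attributes
encoder.eval()                  # used as a fixed feature extractor only

with torch.no_grad():
    imgs = torch.randn(2, 3, 1024, 2048)  # dummy batch at Cityscapes resolution
    feats = encoder(imgs)                 # the stride-2 conv halves height and width
print(feats.shape)                        # torch.Size([2, 3, 512, 1024])
```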
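Because most of the shape bookkeeping in `aspp.py` and `deeplabv3.py` lives in comments, a small smoke test is a cheap way to confirm it. The sketch below assumes it runs from the `deeplabv3/` directory (so that `sys.path.append('model')` resolves) and that the pretrained ResNet checkpoint path hard-coded in `resnet.py` exists on the machine; with the default `ResNet18_OS8` backbone the feature map is (batch, 512, h/8, w/8), ASPP maps it to `num_classes` channels, and the final bilinear interpolation restores the input resolution.

```python
# Smoke test for the DeepLabV3 shape flow; paths and sizes are assumptions.
import torch
from model.deeplabv3 import DeepLabV3

net = DeepLabV3(model_id="smoke_test", project_dir=".")
net.eval()  # eval mode is required for BatchNorm with a batch of size 1

with torch.no_grad():
    x = torch.randn(1, 3, 256, 512)  # h and w divisible by 8 for ResNet18_OS8
    out = net(x)                     # backbone -> ASPP -> upsample back to (h, w)
print(out.shape)                     # expected: torch.Size([1, 20, 256, 512])
```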