Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 130 additions & 55 deletions examples/lifelong_learning/RFNet/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import torch
from PIL import Image
import argparse
from tqdm import tqdm

from torchvision import transforms
from torch.utils.data import DataLoader
from sedna.common.config import Context
from sedna.common.file_ops import FileOps
from sedna.common.log import LOGGER

from utils.metrics import Evaluator
from train import Trainer
from eval import Validator
from tqdm import tqdm
from eval import load_my_state_dict
from utils.metrics import Evaluator
from dataloaders import make_data_loader
from dataloaders import custom_transforms as tr
from torchvision import transforms
from sedna.common.class_factory import ClassType, ClassFactory
from sedna.common.config import Context
from sedna.datasources import TxtDataParse
from torch.utils.data import DataLoader
from sedna.common.file_ops import FileOps
from utils.lr_scheduler import LR_Scheduler

def preprocess(image_urls):
transformed_images = []
Expand All @@ -34,7 +34,10 @@ def preprocess(image_urls):
composed_transforms = transforms.Compose([
# tr.CropBlackArea(),
# tr.FixedResize(size=self.args.crop_size),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.Normalize(
mean=(
0.485, 0.456, 0.406), std=(
0.229, 0.224, 0.225)),
tr.ToTensor()])

transformed_images.append((composed_transforms(sample), img_path))
Expand All @@ -52,78 +55,113 @@ def __init__(self, **kwargs):
self.train_args.no_val = kwargs.get("no_val", True)
# self.train_args.resume = Context.get_parameters("PRETRAINED_MODEL_URL", None)
self.trainer = None

label_save_dir = Context.get_parameters("INFERENCE_RESULT_DIR", "./inference_results")
self.val_args.color_label_save_path = os.path.join(label_save_dir, "color")
self.val_args.merge_label_save_path = os.path.join(label_save_dir, "merge")
self.train_model_url = None

label_save_dir = Context.get_parameters(
"INFERENCE_RESULT_DIR", "./inference_results")
self.val_args.color_label_save_path = os.path.join(
label_save_dir, "color")
self.val_args.merge_label_save_path = os.path.join(
label_save_dir, "merge")
self.val_args.label_save_path = os.path.join(label_save_dir, "label")
self.val_args.save_predicted_image = kwargs.get(
"save_predicted_image", "true").lower()
self.validator = Validator(self.val_args)

def train(self, train_data, valid_data=None, **kwargs):
def train(self, train_data, valid_data=None, **kwargs):
self.trainer = Trainer(self.train_args, train_data=train_data)
print("Total epoches:", self.trainer.args.epochs)
for epoch in range(self.trainer.args.start_epoch, self.trainer.args.epochs):
for epoch in range(
self.trainer.args.start_epoch,
self.trainer.args.epochs):
if epoch == 0 and self.trainer.val_loader:
self.trainer.validation(epoch)
self.trainer.training(epoch)

if self.trainer.args.no_val and \
(epoch % self.trainer.args.eval_interval == (self.trainer.args.eval_interval - 1)
or epoch == self.trainer.args.epochs - 1):
# save checkpoint when it meets eval_interval or the training finished
if self.trainer.args.no_val and (
epoch %
self.trainer.args.eval_interval == (
self.trainer.args.eval_interval -
1) or epoch == self.trainer.args.epochs -
1):
# save checkpoint when it meets eval_interval or the training
# finished
is_best = False
checkpoint_path = self.trainer.saver.save_checkpoint({
self.train_model_url = self.trainer.saver.save_checkpoint({
'epoch': epoch + 1,
'state_dict': self.trainer.model.state_dict(),
'optimizer': self.trainer.optimizer.state_dict(),
'best_pred': self.trainer.best_pred,
}, is_best)

# if not self.trainer.args.no_val and \
# epoch % self.train_args.eval_interval == (self.train_args.eval_interval - 1) \
# and self.trainer.val_loader:
# self.trainer.validation(epoch)

self.trainer.writer.close()

return checkpoint_path
return self.train_model_url

def predict(self, data, **kwargs):
if not isinstance(data[0][0], dict):
data = preprocess(data)

if type(data) is np.ndarray:
if isinstance(data, np.ndarray):
data = data.tolist()

self.validator.test_loader = DataLoader(data, batch_size=self.val_args.test_batch_size, shuffle=False,
pin_memory=True)
self.validator.test_loader = DataLoader(
data,
batch_size=self.val_args.test_batch_size,
shuffle=False,
pin_memory=True)
return self.validator.validate()

def evaluate(self, data, **kwargs):
self.val_args.save_predicted_image = kwargs.get("save_predicted_image", True)
samples = preprocess(data.x)
predictions = self.predict(samples)
return accuracy(data.y, predictions)

def load(self, model_url, **kwargs):
if model_url:
self.validator.new_state_dict = torch.load(model_url, map_location=torch.device("cpu"))
self.validator.new_state_dict = torch.load(
model_url, map_location=torch.device("cpu"))
self.validator.model = load_my_state_dict(
self.validator.model, self.validator.new_state_dict['state_dict'])

self.train_args.resume = model_url
else:
raise Exception("model url does not exist.")
self.validator.model = load_my_state_dict(self.validator.model, self.validator.new_state_dict['state_dict'])

def save(self, model_path=None):
# TODO: how to save unstructured data model
pass
# TODO: save unstructured data model
if not model_path:
LOGGER.warning(f"Not specify model path.")
return self.train_model_url

return FileOps.upload(self.train_model_url, model_path)


def train_args():
parser = argparse.ArgumentParser(description="PyTorch RFNet Training")
parser.add_argument('--depth', action="store_true", default=False,
help='training with depth image or not (default: False)')
parser.add_argument('--dataset', type=str, default='cityscapes',
choices=['citylostfound', 'cityscapes', 'cityrand', 'target', 'xrlab', 'e1', 'mapillary'],
help='dataset name (default: cityscapes)')
parser.add_argument(
'--depth',
action="store_true",
default=False,
help='training with depth image or not (default: False)')
parser.add_argument(
'--dataset',
type=str,
default='cityscapes',
choices=[
'citylostfound',
'cityscapes',
'cityrand',
'target',
'xrlab',
'e1',
'mapillary'],
help='dataset name (default: cityscapes)')
parser.add_argument('--workers', type=int, default=4,
metavar='N', help='dataloader threads')
parser.add_argument('--base-size', type=int, default=1024,
Expand All @@ -149,8 +187,11 @@ def train_args():
parser.add_argument('--test-batch-size', type=int, default=1,
metavar='N', help='input batch size for \
testing (default: auto)')
parser.add_argument('--use-balanced-weights', action='store_true', default=False,
help='whether to use balanced weights (default: True)')
parser.add_argument(
'--use-balanced-weights',
action='store_true',
default=False,
help='whether to use balanced weights (default: True)')
parser.add_argument('--num-class', type=int, default=24,
help='number of training classes (default: 24')
# optimizer params
Expand All @@ -164,8 +205,11 @@ def train_args():
parser.add_argument('--weight-decay', type=float, default=2.5e-5,
metavar='M', help='w-decay (default: 5e-4)')
# cuda, seed and logging
parser.add_argument('--no-cuda', action='store_true', default=
False, help='disables CUDA training')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument('--gpu-ids', type=str, default='0',
help='use which gpu to train, must be a \
comma-separated list of integers only (default=0)')
Expand Down Expand Up @@ -193,7 +237,8 @@ def train_args():
try:
args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
except ValueError:
raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
raise ValueError(
'Argument --gpu_ids must be a comma-separated list of integers only')

if args.epochs is None:
epoches = {
Expand All @@ -214,7 +259,8 @@ def train_args():
'citylostfound': 0.0001,
'cityrand': 0.0001
}
args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
args.lr = lrs[args.dataset.lower()] / \
(4 * len(args.gpu_ids)) * args.batch_size

if args.checkname is None:
args.checkname = 'RFNet'
Expand All @@ -223,11 +269,19 @@ def train_args():

return args


def val_args():
parser = argparse.ArgumentParser(description="PyTorch RFNet validation")
parser.add_argument('--dataset', type=str, default='cityscapes',
choices=['citylostfound', 'cityscapes', 'xrlab', 'mapillary'],
help='dataset name (default: cityscapes)')
parser.add_argument(
'--dataset',
type=str,
default='cityscapes',
choices=[
'citylostfound',
'cityscapes',
'xrlab',
'mapillary'],
help='dataset name (default: cityscapes)')
parser.add_argument('--workers', type=int, default=4,
metavar='N', help='dataloader threads')
parser.add_argument('--base-size', type=int, default=1024,
Expand All @@ -243,17 +297,27 @@ def val_args():
metavar='N', help='input batch size for \
testing (default: auto)')
parser.add_argument('--num-class', type=int, default=24,
help='number of training classes (default: 24')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
help='number of training classes (default: 24)')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument('--gpu-ids', type=str, default='0',
help='use which gpu to train, must be a \
comma-separated list of integers only (default=0)')
parser.add_argument('--checkname', type=str, default=None,
help='set the checkpoint name')
parser.add_argument('--weight-path', type=str, default="./models/530_exp3_2.pth",
help='enter your path of the weight')
parser.add_argument('--save-predicted-image', action='store_true', default=False,
help='save predicted images')
parser.add_argument(
'--weight-path',
type=str,
default="./models/530_exp3_2.pth",
help='enter your path of the weight')
parser.add_argument(
'--save-predicted-image',
action='store_true',
default=False,
help='save predicted images')
parser.add_argument('--color-label-save-path', type=str,
default='./test/color/',
help='path to save label')
Expand All @@ -262,19 +326,29 @@ def val_args():
help='path to save merged label')
parser.add_argument('--label-save-path', type=str, default='./test/label/',
help='path to save merged label')
parser.add_argument('--merge', action='store_true', default=True, help='merge image and label')
parser.add_argument('--depth', action='store_true', default=False, help='add depth image or not')
parser.add_argument(
'--merge',
action='store_true',
default=False,
help='merge image and label')
parser.add_argument(
'--depth',
action='store_true',
default=False,
help='add depth image or not')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
try:
args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
except ValueError:
raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
raise ValueError(
'Argument --gpu_ids must be a comma-separated list of integers only')

return args


def accuracy(y_true, y_pred, **kwargs):
args = val_args()
_, _, test_loader, num_class = make_data_loader(args, test_data=y_true)
Expand All @@ -291,7 +365,7 @@ def accuracy(y_true, y_pred, **kwargs):
if args.depth:
depth = depth.cuda()

target[target > evaluator.num_class-1] = 255
target[target > evaluator.num_class - 1] = 255
target = target.cpu().numpy()
# Add batch sample into evaluator
evaluator.add_batch(target, y_pred[i])
Expand All @@ -305,6 +379,7 @@ def accuracy(y_true, y_pred, **kwargs):
print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU))
return CPA


if __name__ == '__main__':
model_path = "/tmp/RFNet/"
if not os.path.exists(model_path):
Expand Down
2 changes: 1 addition & 1 deletion examples/lifelong_learning/RFNet/dataloaders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def decode_segmap(label_mask, dataset, plot=False):
n_classes = 21
label_colours = get_pascal_labels()
elif dataset == 'cityscapes':
n_classes = 19
n_classes = 24
label_colours = get_cityscapes_labels()
elif dataset == 'target':
n_classes = 24
Expand Down
5 changes: 2 additions & 3 deletions examples/lifelong_learning/RFNet/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def validate(self):
if self.args.depth:
image, depth, target = sample['image'], sample['depth'], sample['label']
else:
# spec = time.time()
image, target = sample['image'], sample['label']
image, target = sample['image'], sample['label']

if self.args.cuda:
image = image.cuda()
Expand All @@ -82,7 +81,7 @@ def validate(self):
pred = np.argmax(pred, axis=1)
predictions.append(pred)

if not self.args.save_predicted_image:
if self.args.save_predicted_image != "true":
continue

pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte())
Expand Down
Loading