Skip to content

Commit 3ed85a0

Browse files
authored
Merge pull request #350 from luosiqi/dev-lifelong-n
Code check and base model improvement of unstructured lifelong learning framework
2 parents 53e5ae4 + 88827ef commit 3ed85a0

File tree

15 files changed

+256
-160
lines changed

15 files changed

+256
-160
lines changed

examples/lifelong_learning/RFNet/basemodel.py

Lines changed: 130 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
import torch
44
from PIL import Image
55
import argparse
6+
from tqdm import tqdm
7+
8+
from torchvision import transforms
9+
from torch.utils.data import DataLoader
10+
from sedna.common.config import Context
11+
from sedna.common.file_ops import FileOps
12+
from sedna.common.log import LOGGER
13+
14+
from utils.metrics import Evaluator
615
from train import Trainer
716
from eval import Validator
8-
from tqdm import tqdm
917
from eval import load_my_state_dict
10-
from utils.metrics import Evaluator
1118
from dataloaders import make_data_loader
1219
from dataloaders import custom_transforms as tr
13-
from torchvision import transforms
14-
from sedna.common.class_factory import ClassType, ClassFactory
15-
from sedna.common.config import Context
16-
from sedna.datasources import TxtDataParse
17-
from torch.utils.data import DataLoader
18-
from sedna.common.file_ops import FileOps
19-
from utils.lr_scheduler import LR_Scheduler
2020

2121
def preprocess(image_urls):
2222
transformed_images = []
@@ -34,7 +34,10 @@ def preprocess(image_urls):
3434
composed_transforms = transforms.Compose([
3535
# tr.CropBlackArea(),
3636
# tr.FixedResize(size=self.args.crop_size),
37-
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
37+
tr.Normalize(
38+
mean=(
39+
0.485, 0.456, 0.406), std=(
40+
0.229, 0.224, 0.225)),
3841
tr.ToTensor()])
3942

4043
transformed_images.append((composed_transforms(sample), img_path))
@@ -52,78 +55,113 @@ def __init__(self, **kwargs):
5255
self.train_args.no_val = kwargs.get("no_val", True)
5356
# self.train_args.resume = Context.get_parameters("PRETRAINED_MODEL_URL", None)
5457
self.trainer = None
55-
56-
label_save_dir = Context.get_parameters("INFERENCE_RESULT_DIR", "./inference_results")
57-
self.val_args.color_label_save_path = os.path.join(label_save_dir, "color")
58-
self.val_args.merge_label_save_path = os.path.join(label_save_dir, "merge")
58+
self.train_model_url = None
59+
60+
label_save_dir = Context.get_parameters(
61+
"INFERENCE_RESULT_DIR", "./inference_results")
62+
self.val_args.color_label_save_path = os.path.join(
63+
label_save_dir, "color")
64+
self.val_args.merge_label_save_path = os.path.join(
65+
label_save_dir, "merge")
5966
self.val_args.label_save_path = os.path.join(label_save_dir, "label")
67+
self.val_args.save_predicted_image = kwargs.get(
68+
"save_predicted_image", "true").lower()
6069
self.validator = Validator(self.val_args)
6170

62-
def train(self, train_data, valid_data=None, **kwargs):
71+
def train(self, train_data, valid_data=None, **kwargs):
6372
self.trainer = Trainer(self.train_args, train_data=train_data)
6473
print("Total epoches:", self.trainer.args.epochs)
65-
for epoch in range(self.trainer.args.start_epoch, self.trainer.args.epochs):
74+
for epoch in range(
75+
self.trainer.args.start_epoch,
76+
self.trainer.args.epochs):
6677
if epoch == 0 and self.trainer.val_loader:
6778
self.trainer.validation(epoch)
6879
self.trainer.training(epoch)
6980

70-
if self.trainer.args.no_val and \
71-
(epoch % self.trainer.args.eval_interval == (self.trainer.args.eval_interval - 1)
72-
or epoch == self.trainer.args.epochs - 1):
73-
# save checkpoint when it meets eval_interval or the training finished
81+
if self.trainer.args.no_val and (
82+
epoch %
83+
self.trainer.args.eval_interval == (
84+
self.trainer.args.eval_interval -
85+
1) or epoch == self.trainer.args.epochs -
86+
1):
87+
# save checkpoint when it meets eval_interval or the training
88+
# finished
7489
is_best = False
75-
checkpoint_path = self.trainer.saver.save_checkpoint({
90+
self.train_model_url = self.trainer.saver.save_checkpoint({
7691
'epoch': epoch + 1,
7792
'state_dict': self.trainer.model.state_dict(),
7893
'optimizer': self.trainer.optimizer.state_dict(),
7994
'best_pred': self.trainer.best_pred,
8095
}, is_best)
81-
96+
8297
# if not self.trainer.args.no_val and \
8398
# epoch % self.train_args.eval_interval == (self.train_args.eval_interval - 1) \
8499
# and self.trainer.val_loader:
85100
# self.trainer.validation(epoch)
86101

87102
self.trainer.writer.close()
88103

89-
return checkpoint_path
104+
return self.train_model_url
90105

91106
def predict(self, data, **kwargs):
92107
if not isinstance(data[0][0], dict):
93108
data = preprocess(data)
94109

95-
if type(data) is np.ndarray:
110+
if isinstance(data, np.ndarray):
96111
data = data.tolist()
97112

98-
self.validator.test_loader = DataLoader(data, batch_size=self.val_args.test_batch_size, shuffle=False,
99-
pin_memory=True)
113+
self.validator.test_loader = DataLoader(
114+
data,
115+
batch_size=self.val_args.test_batch_size,
116+
shuffle=False,
117+
pin_memory=True)
100118
return self.validator.validate()
101119

102120
def evaluate(self, data, **kwargs):
103-
self.val_args.save_predicted_image = kwargs.get("save_predicted_image", True)
104121
samples = preprocess(data.x)
105122
predictions = self.predict(samples)
106123
return accuracy(data.y, predictions)
107124

108125
def load(self, model_url, **kwargs):
109126
if model_url:
110-
self.validator.new_state_dict = torch.load(model_url, map_location=torch.device("cpu"))
127+
self.validator.new_state_dict = torch.load(
128+
model_url, map_location=torch.device("cpu"))
129+
self.validator.model = load_my_state_dict(
130+
self.validator.model, self.validator.new_state_dict['state_dict'])
131+
111132
self.train_args.resume = model_url
112133
else:
113134
raise Exception("model url does not exist.")
114-
self.validator.model = load_my_state_dict(self.validator.model, self.validator.new_state_dict['state_dict'])
115135

116136
def save(self, model_path=None):
117-
# TODO: how to save unstructured data model
118-
pass
137+
# TODO: save unstructured data model
138+
if not model_path:
139+
LOGGER.warning(f"Not specify model path.")
140+
return self.train_model_url
141+
142+
return FileOps.upload(self.train_model_url, model_path)
143+
119144

120145
def train_args():
121146
parser = argparse.ArgumentParser(description="PyTorch RFNet Training")
122-
parser.add_argument('--depth', action="store_true", default=False,
123-
help='training with depth image or not (default: False)')
124-
parser.add_argument('--dataset', type=str, default='cityscapes',
125-
choices=['citylostfound', 'cityscapes', 'cityrand', 'target', 'xrlab', 'e1', 'mapillary'],
126-
help='dataset name (default: cityscapes)')
147+
parser.add_argument(
148+
'--depth',
149+
action="store_true",
150+
default=False,
151+
help='training with depth image or not (default: False)')
152+
parser.add_argument(
153+
'--dataset',
154+
type=str,
155+
default='cityscapes',
156+
choices=[
157+
'citylostfound',
158+
'cityscapes',
159+
'cityrand',
160+
'target',
161+
'xrlab',
162+
'e1',
163+
'mapillary'],
164+
help='dataset name (default: cityscapes)')
127165
parser.add_argument('--workers', type=int, default=4,
128166
metavar='N', help='dataloader threads')
129167
parser.add_argument('--base-size', type=int, default=1024,
@@ -149,8 +187,11 @@ def train_args():
149187
parser.add_argument('--test-batch-size', type=int, default=1,
150188
metavar='N', help='input batch size for \
151189
testing (default: auto)')
152-
parser.add_argument('--use-balanced-weights', action='store_true', default=False,
153-
help='whether to use balanced weights (default: True)')
190+
parser.add_argument(
191+
'--use-balanced-weights',
192+
action='store_true',
193+
default=False,
194+
help='whether to use balanced weights (default: True)')
154195
parser.add_argument('--num-class', type=int, default=24,
155196
help='number of training classes (default: 24')
156197
# optimizer params
@@ -164,8 +205,11 @@ def train_args():
164205
parser.add_argument('--weight-decay', type=float, default=2.5e-5,
165206
metavar='M', help='w-decay (default: 5e-4)')
166207
# cuda, seed and logging
167-
parser.add_argument('--no-cuda', action='store_true', default=
168-
False, help='disables CUDA training')
208+
parser.add_argument(
209+
'--no-cuda',
210+
action='store_true',
211+
default=False,
212+
help='disables CUDA training')
169213
parser.add_argument('--gpu-ids', type=str, default='0',
170214
help='use which gpu to train, must be a \
171215
comma-separated list of integers only (default=0)')
@@ -193,7 +237,8 @@ def train_args():
193237
try:
194238
args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
195239
except ValueError:
196-
raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
240+
raise ValueError(
241+
'Argument --gpu_ids must be a comma-separated list of integers only')
197242

198243
if args.epochs is None:
199244
epoches = {
@@ -214,7 +259,8 @@ def train_args():
214259
'citylostfound': 0.0001,
215260
'cityrand': 0.0001
216261
}
217-
args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
262+
args.lr = lrs[args.dataset.lower()] / \
263+
(4 * len(args.gpu_ids)) * args.batch_size
218264

219265
if args.checkname is None:
220266
args.checkname = 'RFNet'
@@ -223,11 +269,19 @@ def train_args():
223269

224270
return args
225271

272+
226273
def val_args():
227274
parser = argparse.ArgumentParser(description="PyTorch RFNet validation")
228-
parser.add_argument('--dataset', type=str, default='cityscapes',
229-
choices=['citylostfound', 'cityscapes', 'xrlab', 'mapillary'],
230-
help='dataset name (default: cityscapes)')
275+
parser.add_argument(
276+
'--dataset',
277+
type=str,
278+
default='cityscapes',
279+
choices=[
280+
'citylostfound',
281+
'cityscapes',
282+
'xrlab',
283+
'mapillary'],
284+
help='dataset name (default: cityscapes)')
231285
parser.add_argument('--workers', type=int, default=4,
232286
metavar='N', help='dataloader threads')
233287
parser.add_argument('--base-size', type=int, default=1024,
@@ -243,17 +297,27 @@ def val_args():
243297
metavar='N', help='input batch size for \
244298
testing (default: auto)')
245299
parser.add_argument('--num-class', type=int, default=24,
246-
help='number of training classes (default: 24')
247-
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
300+
help='number of training classes (default: 24)')
301+
parser.add_argument(
302+
'--no-cuda',
303+
action='store_true',
304+
default=False,
305+
help='disables CUDA training')
248306
parser.add_argument('--gpu-ids', type=str, default='0',
249307
help='use which gpu to train, must be a \
250308
comma-separated list of integers only (default=0)')
251309
parser.add_argument('--checkname', type=str, default=None,
252310
help='set the checkpoint name')
253-
parser.add_argument('--weight-path', type=str, default="./models/530_exp3_2.pth",
254-
help='enter your path of the weight')
255-
parser.add_argument('--save-predicted-image', action='store_true', default=False,
256-
help='save predicted images')
311+
parser.add_argument(
312+
'--weight-path',
313+
type=str,
314+
default="./models/530_exp3_2.pth",
315+
help='enter your path of the weight')
316+
parser.add_argument(
317+
'--save-predicted-image',
318+
action='store_true',
319+
default=False,
320+
help='save predicted images')
257321
parser.add_argument('--color-label-save-path', type=str,
258322
default='./test/color/',
259323
help='path to save label')
@@ -262,19 +326,29 @@ def val_args():
262326
help='path to save merged label')
263327
parser.add_argument('--label-save-path', type=str, default='./test/label/',
264328
help='path to save merged label')
265-
parser.add_argument('--merge', action='store_true', default=True, help='merge image and label')
266-
parser.add_argument('--depth', action='store_true', default=False, help='add depth image or not')
329+
parser.add_argument(
330+
'--merge',
331+
action='store_true',
332+
default=False,
333+
help='merge image and label')
334+
parser.add_argument(
335+
'--depth',
336+
action='store_true',
337+
default=False,
338+
help='add depth image or not')
267339

268340
args = parser.parse_args()
269341
args.cuda = not args.no_cuda and torch.cuda.is_available()
270342
if args.cuda:
271343
try:
272344
args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
273345
except ValueError:
274-
raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
346+
raise ValueError(
347+
'Argument --gpu_ids must be a comma-separated list of integers only')
275348

276349
return args
277350

351+
278352
def accuracy(y_true, y_pred, **kwargs):
279353
args = val_args()
280354
_, _, test_loader, num_class = make_data_loader(args, test_data=y_true)
@@ -291,7 +365,7 @@ def accuracy(y_true, y_pred, **kwargs):
291365
if args.depth:
292366
depth = depth.cuda()
293367

294-
target[target > evaluator.num_class-1] = 255
368+
target[target > evaluator.num_class - 1] = 255
295369
target = target.cpu().numpy()
296370
# Add batch sample into evaluator
297371
evaluator.add_batch(target, y_pred[i])
@@ -305,6 +379,7 @@ def accuracy(y_true, y_pred, **kwargs):
305379
print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU))
306380
return CPA
307381

382+
308383
if __name__ == '__main__':
309384
model_path = "/tmp/RFNet/"
310385
if not os.path.exists(model_path):

examples/lifelong_learning/RFNet/dataloaders/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def decode_segmap(label_mask, dataset, plot=False):
2525
n_classes = 21
2626
label_colours = get_pascal_labels()
2727
elif dataset == 'cityscapes':
28-
n_classes = 19
28+
n_classes = 24
2929
label_colours = get_cityscapes_labels()
3030
elif dataset == 'target':
3131
n_classes = 24

examples/lifelong_learning/RFNet/eval.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def validate(self):
6060
if self.args.depth:
6161
image, depth, target = sample['image'], sample['depth'], sample['label']
6262
else:
63-
# spec = time.time()
64-
image, target = sample['image'], sample['label']
63+
image, target = sample['image'], sample['label']
6564

6665
if self.args.cuda:
6766
image = image.cuda()
@@ -82,7 +81,7 @@ def validate(self):
8281
pred = np.argmax(pred, axis=1)
8382
predictions.append(pred)
8483

85-
if not self.args.save_predicted_image:
84+
if self.args.save_predicted_image != "true":
8685
continue
8786

8887
pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte())

0 commit comments

Comments
 (0)