33import torch
44from PIL import Image
55import argparse
6+ from tqdm import tqdm
7+
8+ from torchvision import transforms
9+ from torch .utils .data import DataLoader
10+ from sedna .common .config import Context
11+ from sedna .common .file_ops import FileOps
12+ from sedna .common .log import LOGGER
13+
14+ from utils .metrics import Evaluator
615from train import Trainer
716from eval import Validator
8- from tqdm import tqdm
917from eval import load_my_state_dict
10- from utils .metrics import Evaluator
1118from dataloaders import make_data_loader
1219from dataloaders import custom_transforms as tr
13- from torchvision import transforms
14- from sedna .common .class_factory import ClassType , ClassFactory
15- from sedna .common .config import Context
16- from sedna .datasources import TxtDataParse
17- from torch .utils .data import DataLoader
18- from sedna .common .file_ops import FileOps
19- from utils .lr_scheduler import LR_Scheduler
2020
2121def preprocess (image_urls ):
2222 transformed_images = []
@@ -34,7 +34,10 @@ def preprocess(image_urls):
3434 composed_transforms = transforms .Compose ([
3535 # tr.CropBlackArea(),
3636 # tr.FixedResize(size=self.args.crop_size),
37- tr .Normalize (mean = (0.485 , 0.456 , 0.406 ), std = (0.229 , 0.224 , 0.225 )),
37+ tr .Normalize (
38+ mean = (
39+ 0.485 , 0.456 , 0.406 ), std = (
40+ 0.229 , 0.224 , 0.225 )),
3841 tr .ToTensor ()])
3942
4043 transformed_images .append ((composed_transforms (sample ), img_path ))
@@ -52,78 +55,113 @@ def __init__(self, **kwargs):
5255 self .train_args .no_val = kwargs .get ("no_val" , True )
5356 # self.train_args.resume = Context.get_parameters("PRETRAINED_MODEL_URL", None)
5457 self .trainer = None
55-
56- label_save_dir = Context .get_parameters ("INFERENCE_RESULT_DIR" , "./inference_results" )
57- self .val_args .color_label_save_path = os .path .join (label_save_dir , "color" )
58- self .val_args .merge_label_save_path = os .path .join (label_save_dir , "merge" )
58+ self .train_model_url = None
59+
60+ label_save_dir = Context .get_parameters (
61+ "INFERENCE_RESULT_DIR" , "./inference_results" )
62+ self .val_args .color_label_save_path = os .path .join (
63+ label_save_dir , "color" )
64+ self .val_args .merge_label_save_path = os .path .join (
65+ label_save_dir , "merge" )
5966 self .val_args .label_save_path = os .path .join (label_save_dir , "label" )
67+ self .val_args .save_predicted_image = kwargs .get (
68+ "save_predicted_image" , "true" ).lower ()
6069 self .validator = Validator (self .val_args )
6170
def train(self, train_data, valid_data=None, **kwargs):
    """Train the RFNet model on ``train_data``.

    Runs the epoch loop of the underlying ``Trainer``, periodically
    saving a checkpoint, and returns the URL/path of the last checkpoint
    written (``None`` if no checkpoint was ever saved).

    ``valid_data`` and ``kwargs`` are accepted for interface
    compatibility but unused here.
    """
    self.trainer = Trainer(self.train_args, train_data=train_data)
    args = self.trainer.args
    print("Total epoches:", args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        # Baseline validation pass before any training, when a val set exists.
        if epoch == 0 and self.trainer.val_loader:
            self.trainer.validation(epoch)
        self.trainer.training(epoch)

        interval = args.eval_interval
        finished = epoch == args.epochs - 1
        # Save a checkpoint whenever an eval interval completes, or when
        # training has finished.
        if args.no_val and (epoch % interval == interval - 1 or finished):
            self.train_model_url = self.trainer.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.trainer.model.state_dict(),
                'optimizer': self.trainer.optimizer.state_dict(),
                'best_pred': self.trainer.best_pred,
            }, False)  # is_best=False: best-model tracking not used here

    self.trainer.writer.close()

    # self.train_model_url (set in __init__) stays None if nothing was
    # saved, avoiding the unbound-local the old checkpoint_path had.
    return self.train_model_url
90105
def predict(self, data, **kwargs):
    """Run segmentation inference on a batch of samples.

    ``data`` is either a list of raw image URLs (preprocessed here) or a
    list of already-transformed ``(sample, path)`` pairs whose first
    element is a dict.  Returns whatever the validator produces for the
    batch.
    """
    # Raw inputs have not been through the transform pipeline yet.
    if not isinstance(data[0][0], dict):
        data = preprocess(data)

    # Some callers hand over a numpy array; DataLoader wants a sequence.
    if isinstance(data, np.ndarray):
        data = data.tolist()

    loader = DataLoader(
        data,
        batch_size=self.val_args.test_batch_size,
        shuffle=False,
        pin_memory=True,
    )
    self.validator.test_loader = loader
    return self.validator.validate()
101119
def evaluate(self, data, **kwargs):
    """Score the model on an evaluation set.

    ``data`` carries image URLs in ``data.x`` and ground-truth labels in
    ``data.y``; returns the accuracy metric (CPA) computed by
    ``accuracy``.
    """
    samples = preprocess(data.x)
    return accuracy(data.y, self.predict(samples))
107124
def load(self, model_url, **kwargs):
    """Restore model weights from a checkpoint file.

    Loads the checkpoint on CPU, copies its ``'state_dict'`` into the
    validator's model, and records the URL so training can resume from
    it.

    Raises:
        Exception: if ``model_url`` is empty/None.
    """
    if not model_url:
        raise Exception("model url does not exist.")

    checkpoint = torch.load(model_url, map_location=torch.device("cpu"))
    self.validator.new_state_dict = checkpoint
    self.validator.model = load_my_state_dict(
        self.validator.model, checkpoint['state_dict'])
    self.train_args.resume = model_url
def save(self, model_path=None):
    """Publish the last training checkpoint.

    If ``model_path`` is falsy, log a warning and return the local
    checkpoint URL unchanged; otherwise upload the checkpoint to
    ``model_path`` via ``FileOps`` and return the upload result.
    """
    # TODO: save unstructured data model
    if not model_path:
        # Was an f-string with no placeholders and broken grammar
        # ("Not specify model path."); plain literal, clearer wording.
        LOGGER.warning("No model path specified.")
        return self.train_model_url

    return FileOps.upload(self.train_model_url, model_path)
119144
120145def train_args ():
121146 parser = argparse .ArgumentParser (description = "PyTorch RFNet Training" )
122- parser .add_argument ('--depth' , action = "store_true" , default = False ,
123- help = 'training with depth image or not (default: False)' )
124- parser .add_argument ('--dataset' , type = str , default = 'cityscapes' ,
125- choices = ['citylostfound' , 'cityscapes' , 'cityrand' , 'target' , 'xrlab' , 'e1' , 'mapillary' ],
126- help = 'dataset name (default: cityscapes)' )
147+ parser .add_argument (
148+ '--depth' ,
149+ action = "store_true" ,
150+ default = False ,
151+ help = 'training with depth image or not (default: False)' )
152+ parser .add_argument (
153+ '--dataset' ,
154+ type = str ,
155+ default = 'cityscapes' ,
156+ choices = [
157+ 'citylostfound' ,
158+ 'cityscapes' ,
159+ 'cityrand' ,
160+ 'target' ,
161+ 'xrlab' ,
162+ 'e1' ,
163+ 'mapillary' ],
164+ help = 'dataset name (default: cityscapes)' )
127165 parser .add_argument ('--workers' , type = int , default = 4 ,
128166 metavar = 'N' , help = 'dataloader threads' )
129167 parser .add_argument ('--base-size' , type = int , default = 1024 ,
@@ -149,8 +187,11 @@ def train_args():
149187 parser .add_argument ('--test-batch-size' , type = int , default = 1 ,
150188 metavar = 'N' , help = 'input batch size for \
151189 testing (default: auto)' )
152- parser .add_argument ('--use-balanced-weights' , action = 'store_true' , default = False ,
153- help = 'whether to use balanced weights (default: True)' )
190+ parser .add_argument (
191+ '--use-balanced-weights' ,
192+ action = 'store_true' ,
193+ default = False ,
194+ help = 'whether to use balanced weights (default: True)' )
154195 parser .add_argument ('--num-class' , type = int , default = 24 ,
155196 help = 'number of training classes (default: 24' )
156197 # optimizer params
@@ -164,8 +205,11 @@ def train_args():
164205 parser .add_argument ('--weight-decay' , type = float , default = 2.5e-5 ,
165206 metavar = 'M' , help = 'w-decay (default: 5e-4)' )
166207 # cuda, seed and logging
167- parser .add_argument ('--no-cuda' , action = 'store_true' , default =
168- False , help = 'disables CUDA training' )
208+ parser .add_argument (
209+ '--no-cuda' ,
210+ action = 'store_true' ,
211+ default = False ,
212+ help = 'disables CUDA training' )
169213 parser .add_argument ('--gpu-ids' , type = str , default = '0' ,
170214 help = 'use which gpu to train, must be a \
171215 comma-separated list of integers only (default=0)' )
@@ -193,7 +237,8 @@ def train_args():
193237 try :
194238 args .gpu_ids = [int (s ) for s in args .gpu_ids .split (',' )]
195239 except ValueError :
196- raise ValueError ('Argument --gpu_ids must be a comma-separated list of integers only' )
240+ raise ValueError (
241+ 'Argument --gpu_ids must be a comma-separated list of integers only' )
197242
198243 if args .epochs is None :
199244 epoches = {
@@ -214,7 +259,8 @@ def train_args():
214259 'citylostfound' : 0.0001 ,
215260 'cityrand' : 0.0001
216261 }
217- args .lr = lrs [args .dataset .lower ()] / (4 * len (args .gpu_ids )) * args .batch_size
262+ args .lr = lrs [args .dataset .lower ()] / \
263+ (4 * len (args .gpu_ids )) * args .batch_size
218264
219265 if args .checkname is None :
220266 args .checkname = 'RFNet'
@@ -223,11 +269,19 @@ def train_args():
223269
224270 return args
225271
272+
226273def val_args ():
227274 parser = argparse .ArgumentParser (description = "PyTorch RFNet validation" )
228- parser .add_argument ('--dataset' , type = str , default = 'cityscapes' ,
229- choices = ['citylostfound' , 'cityscapes' , 'xrlab' , 'mapillary' ],
230- help = 'dataset name (default: cityscapes)' )
275+ parser .add_argument (
276+ '--dataset' ,
277+ type = str ,
278+ default = 'cityscapes' ,
279+ choices = [
280+ 'citylostfound' ,
281+ 'cityscapes' ,
282+ 'xrlab' ,
283+ 'mapillary' ],
284+ help = 'dataset name (default: cityscapes)' )
231285 parser .add_argument ('--workers' , type = int , default = 4 ,
232286 metavar = 'N' , help = 'dataloader threads' )
233287 parser .add_argument ('--base-size' , type = int , default = 1024 ,
@@ -243,17 +297,27 @@ def val_args():
243297 metavar = 'N' , help = 'input batch size for \
244298 testing (default: auto)' )
245299 parser .add_argument ('--num-class' , type = int , default = 24 ,
246- help = 'number of training classes (default: 24' )
247- parser .add_argument ('--no-cuda' , action = 'store_true' , default = False , help = 'disables CUDA training' )
300+ help = 'number of training classes (default: 24)' )
301+ parser .add_argument (
302+ '--no-cuda' ,
303+ action = 'store_true' ,
304+ default = False ,
305+ help = 'disables CUDA training' )
248306 parser .add_argument ('--gpu-ids' , type = str , default = '0' ,
249307 help = 'use which gpu to train, must be a \
250308 comma-separated list of integers only (default=0)' )
251309 parser .add_argument ('--checkname' , type = str , default = None ,
252310 help = 'set the checkpoint name' )
253- parser .add_argument ('--weight-path' , type = str , default = "./models/530_exp3_2.pth" ,
254- help = 'enter your path of the weight' )
255- parser .add_argument ('--save-predicted-image' , action = 'store_true' , default = False ,
256- help = 'save predicted images' )
311+ parser .add_argument (
312+ '--weight-path' ,
313+ type = str ,
314+ default = "./models/530_exp3_2.pth" ,
315+ help = 'enter your path of the weight' )
316+ parser .add_argument (
317+ '--save-predicted-image' ,
318+ action = 'store_true' ,
319+ default = False ,
320+ help = 'save predicted images' )
257321 parser .add_argument ('--color-label-save-path' , type = str ,
258322 default = './test/color/' ,
259323 help = 'path to save label' )
@@ -262,19 +326,29 @@ def val_args():
262326 help = 'path to save merged label' )
263327 parser .add_argument ('--label-save-path' , type = str , default = './test/label/' ,
264328 help = 'path to save merged label' )
265- parser .add_argument ('--merge' , action = 'store_true' , default = True , help = 'merge image and label' )
266- parser .add_argument ('--depth' , action = 'store_true' , default = False , help = 'add depth image or not' )
329+ parser .add_argument (
330+ '--merge' ,
331+ action = 'store_true' ,
332+ default = False ,
333+ help = 'merge image and label' )
334+ parser .add_argument (
335+ '--depth' ,
336+ action = 'store_true' ,
337+ default = False ,
338+ help = 'add depth image or not' )
267339
268340 args = parser .parse_args ()
269341 args .cuda = not args .no_cuda and torch .cuda .is_available ()
270342 if args .cuda :
271343 try :
272344 args .gpu_ids = [int (s ) for s in args .gpu_ids .split (',' )]
273345 except ValueError :
274- raise ValueError ('Argument --gpu_ids must be a comma-separated list of integers only' )
346+ raise ValueError (
347+ 'Argument --gpu_ids must be a comma-separated list of integers only' )
275348
276349 return args
277350
351+
278352def accuracy (y_true , y_pred , ** kwargs ):
279353 args = val_args ()
280354 _ , _ , test_loader , num_class = make_data_loader (args , test_data = y_true )
@@ -291,7 +365,7 @@ def accuracy(y_true, y_pred, **kwargs):
291365 if args .depth :
292366 depth = depth .cuda ()
293367
294- target [target > evaluator .num_class - 1 ] = 255
368+ target [target > evaluator .num_class - 1 ] = 255
295369 target = target .cpu ().numpy ()
296370 # Add batch sample into evaluator
297371 evaluator .add_batch (target , y_pred [i ])
@@ -305,6 +379,7 @@ def accuracy(y_true, y_pred, **kwargs):
305379 print ("CPA:{}, mIoU:{}, fwIoU: {}" .format (CPA , mIoU , FWIoU ))
306380 return CPA
307381
382+
308383if __name__ == '__main__' :
309384 model_path = "/tmp/RFNet/"
310385 if not os .path .exists (model_path ):
0 commit comments