1111from vedaseg .runners import InferenceRunner
1212from vedaseg .utils import Config
1313
14+
1415CLASSES = ('background' , 'aeroplane' , 'bicycle' , 'bird' , 'boat' , 'bottle' ,
1516 'bus' , 'car' , 'cat' , 'chair' , 'cow' , 'diningtable' , 'dog' ,
1617 'horse' , 'motorbike' , 'person' , 'pottedplant' , 'sheep' , 'sofa' ,
2324 [128 , 64 , 0 ], [0 , 192 , 0 ], [128 , 192 , 0 ], [0 , 64 , 128 ]]
2425
2526
26- def inverse_resize (pred , image_shape ):
27- h , w , _ = image_shape
28- reisze_h , resized_w = pred .shape [0 ], pred .shape [1 ]
29- scale_factor = max (h / reisze_h , w / resized_w )
30- pred = cv2 .resize (pred , (
31- int (reisze_h * scale_factor ), int (reisze_h * scale_factor )),
32- interpolation = cv2 .INTER_NEAREST )
33- return pred
27+ def calc_resized_shape (target_shape , image_shape ):
28+ h , w = image_shape
29+ size_h , size_w = target_shape
30+ scale_factor = min (size_h / h , size_w / w )
31+ resized_h , resized_w = int (h * scale_factor ), int (w * scale_factor )
32+ return resized_h , resized_w
33+
34+
35+ def inverse_resize (output , image_shape ):
36+ h , w = image_shape
37+ output = cv2 .resize (output , (w , h ), interpolation = cv2 .INTER_NEAREST )
38+ return output
3439
3540
36- def inverse_pad (pred , image_shape ):
37- h , w , _ = image_shape
38- return pred [:h , :w ]
41+ def inverse_pad (output , image_shape ):
42+ h , w = image_shape
43+ return output [:h , :w ]
3944
4045
4146def plot_result (img , mask , cover ):
@@ -45,10 +50,10 @@ def plot_result(img, mask, cover):
4550 ax [0 ].set_title ('image' )
4651 ax [0 ].imshow (cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB ))
4752
48- ax [1 ].set_title (f 'mask' )
53+ ax [1 ].set_title ('mask' )
4954 ax [1 ].imshow (mask )
5055
51- ax [2 ].set_title (f 'cover' )
56+ ax [2 ].set_title ('cover' )
5257 ax [2 ].imshow (cv2 .cvtColor (cover , cv2 .COLOR_BGR2RGB ))
5358 plt .show ()
5459
@@ -93,16 +98,16 @@ def result(fname,
9398
9499
95100def parse_args ():
96- parser = argparse .ArgumentParser (description = 'Inference a segmentatation model' )
101+ parser = argparse .ArgumentParser (
102+ description = 'Inference a segmentatation model' )
97103 parser .add_argument ('config' , type = str ,
98104 help = 'config file path' )
99- parser .add_argument ('checkpoint' ,
100- type = str , help = 'checkpoint file path' )
101- parser .add_argument ('image' ,
102- type = str ,
105+ parser .add_argument ('checkpoint' , type = str ,
106+ help = 'checkpoint file path' )
107+ parser .add_argument ('image' , type = str ,
103108 help = 'input image path' )
104109 parser .add_argument ('--show' , action = 'store_true' ,
105- help = 'show result' )
110+ help = 'show result images on screen ' )
106111 parser .add_argument ('--need_resize' , action = 'store_true' ,
107112 help = 'set true if there is LongestMaxSize in transform' )
108113 parser .add_argument ('--out' , default = './result' ,
@@ -123,17 +128,23 @@ def main():
123128
124129 runner = InferenceRunner (inference_cfg , common_cfg )
125130 runner .load_checkpoint (args .checkpoint )
131+
126132 image = cv2 .imread (args .image )
127133 image = cv2 .cvtColor (image , cv2 .COLOR_BGR2RGB )
128- h , w , c = image .shape
129- dummy_mask = np .zeros ((h , w ))
134+ image_shape = image .shape [:2 ]
135+ dummy_mask = np .zeros (image_shape )
136+
130137 output = runner (image , [dummy_mask ])
131138 if multi_label :
132139 output = output .transpose ((1 , 2 , 0 ))
140+ output_shape = output .shape [:2 ]
133141
134142 if args .need_resize :
135- output = inverse_resize (output , image .shape )
136- output = inverse_pad (output , image .shape )
143+ resized_shape = calc_resized_shape (output_shape , image_shape )
144+ output = inverse_pad (output , resized_shape )
145+ output = inverse_resize (output , image_shape )
146+ else :
147+ output = inverse_pad (output , image_shape )
137148
138149 result (args .image , output , multi_label = multi_label ,
139150 classes = CLASSES , palette = PALETTE , show = args .show ,
0 commit comments