Skip to content

Commit 911c794

Browse files
authored
Merge pull request #43 from YuxinZou/dist
fix mismatch bug in inference.py
2 parents 180e158 + 607a028 commit 911c794

1 file changed

Lines changed: 34 additions & 23 deletions

File tree

tools/inference.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vedaseg.runners import InferenceRunner
1212
from vedaseg.utils import Config
1313

14+
1415
CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
1516
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
1617
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
@@ -23,19 +24,23 @@
2324
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
2425

2526

26-
def inverse_resize(pred, image_shape):
27-
h, w, _ = image_shape
28-
reisze_h, resized_w = pred.shape[0], pred.shape[1]
29-
scale_factor = max(h / reisze_h, w / resized_w)
30-
pred = cv2.resize(pred, (
31-
int(reisze_h * scale_factor), int(reisze_h * scale_factor)),
32-
interpolation=cv2.INTER_NEAREST)
33-
return pred
27+
def calc_resized_shape(target_shape, image_shape):
28+
h, w = image_shape
29+
size_h, size_w = target_shape
30+
scale_factor = min(size_h / h, size_w / w)
31+
resized_h, resized_w = int(h * scale_factor), int(w * scale_factor)
32+
return resized_h, resized_w
33+
34+
35+
def inverse_resize(output, image_shape):
36+
h, w = image_shape
37+
output = cv2.resize(output, (w, h), interpolation=cv2.INTER_NEAREST)
38+
return output
3439

3540

36-
def inverse_pad(pred, image_shape):
37-
h, w, _ = image_shape
38-
return pred[:h, :w]
41+
def inverse_pad(output, image_shape):
42+
h, w = image_shape
43+
return output[:h, :w]
3944

4045

4146
def plot_result(img, mask, cover):
@@ -45,10 +50,10 @@ def plot_result(img, mask, cover):
4550
ax[0].set_title('image')
4651
ax[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
4752

48-
ax[1].set_title(f'mask')
53+
ax[1].set_title('mask')
4954
ax[1].imshow(mask)
5055

51-
ax[2].set_title(f'cover')
56+
ax[2].set_title('cover')
5257
ax[2].imshow(cv2.cvtColor(cover, cv2.COLOR_BGR2RGB))
5358
plt.show()
5459

@@ -93,16 +98,16 @@ def result(fname,
9398

9499

95100
def parse_args():
96-
parser = argparse.ArgumentParser(description='Inference a segmentatation model')
101+
parser = argparse.ArgumentParser(
102+
description='Inference a segmentatation model')
97103
parser.add_argument('config', type=str,
98104
help='config file path')
99-
parser.add_argument('checkpoint',
100-
type=str, help='checkpoint file path')
101-
parser.add_argument('image',
102-
type=str,
105+
parser.add_argument('checkpoint', type=str,
106+
help='checkpoint file path')
107+
parser.add_argument('image', type=str,
103108
help='input image path')
104109
parser.add_argument('--show', action='store_true',
105-
help='show result')
110+
help='show result images on screen')
106111
parser.add_argument('--need_resize', action='store_true',
107112
help='set true if there is LongestMaxSize in transform')
108113
parser.add_argument('--out', default='./result',
@@ -123,17 +128,23 @@ def main():
123128

124129
runner = InferenceRunner(inference_cfg, common_cfg)
125130
runner.load_checkpoint(args.checkpoint)
131+
126132
image = cv2.imread(args.image)
127133
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
128-
h, w, c = image.shape
129-
dummy_mask = np.zeros((h, w))
134+
image_shape = image.shape[:2]
135+
dummy_mask = np.zeros(image_shape)
136+
130137
output = runner(image, [dummy_mask])
131138
if multi_label:
132139
output = output.transpose((1, 2, 0))
140+
output_shape = output.shape[:2]
133141

134142
if args.need_resize:
135-
output = inverse_resize(output, image.shape)
136-
output = inverse_pad(output, image.shape)
143+
resized_shape = calc_resized_shape(output_shape, image_shape)
144+
output = inverse_pad(output, resized_shape)
145+
output = inverse_resize(output, image_shape)
146+
else:
147+
output = inverse_pad(output, image_shape)
137148

138149
result(args.image, output, multi_label=multi_label,
139150
classes=CLASSES, palette=PALETTE, show=args.show,

0 commit comments

Comments
 (0)