Skip to content

Commit d586af9

Browse files
RunningLeonlvhan028
authored andcommitted
[Enhancement]: Support fcn_unet deployment with dynamic shape (#251)
* support mmseg fcn+unet dynamic shape * add test * fix ci * fix units * resolve comments
1 parent e3e0b1c commit d586af9

File tree

7 files changed

+131
-11
lines changed

7 files changed

+131
-11
lines changed

docs/en/codebases/mmseg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
1515
| DeepLabV3 | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
1616
| DeepLabV3+ | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
1717
| Fast-SCNN[*](#static_shape) | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
18-
| UNet[*](#static_shape) | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
18+
| UNet | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
1919
| ANN[*](#static_shape) | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
2020
| APCNet | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
2121
| BiSeNetV1 | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |

docs/en/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ The table below lists the models that are guaranteed to be exportable to other b
2929
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
3030
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
3131
| Fast-SCNN[*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
32-
| UNet[*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
32+
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
3333
| ANN[*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
3434
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
3535
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .decode_heads import * # noqa: F401,F403
33
from .segmentors import * # noqa: F401,F403
4+
from .utils import * # noqa: F401,F403
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .up_conv_block import up_conv_block__forward
3+
4+
__all__ = ['up_conv_block__forward']
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
3+
import torch
4+
5+
from mmdeploy.core import FUNCTION_REWRITER
6+
from mmdeploy.utils import is_dynamic_shape
7+
8+
9+
@FUNCTION_REWRITER.register_rewriter(
10+
func_name='mmseg.models.utils.UpConvBlock.forward')
11+
def up_conv_block__forward(ctx, self, skip, x):
12+
"""Rewrite `forward` for default backend.
13+
14+
To support dynamic shape for UNet backbone,
15+
upsample feature maps with `size` instead of `scale_factor`
16+
17+
Args:
18+
ctx (ContextCaller): The context with additional information.
19+
self: The instance of the original class.
20+
skip (Tensor): Skip branch feature.
21+
x (Tensor): Input feature to be upsampled.
22+
23+
Returns:
24+
Tensor: Upsampled output feature map.
25+
"""
26+
from mmcv.cnn import ConvModule
27+
28+
# only valid when self.upsample is from build_upsample_layer
29+
if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule):
30+
# upsample with `size` instead of `scale_factor`
31+
from mmseg.ops import Upsample
32+
for c in self.upsample.interp_upsample:
33+
if isinstance(c, Upsample):
34+
c.size = skip.shape[-2:]
35+
c.scale_factor = None
36+
37+
x = self.upsample(x)
38+
out = torch.cat([skip, x], dim=1)
39+
out = self.conv_block(out)
40+
return out

mmdeploy/utils/test.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,12 @@ def get_flatten_inputs(
333333
if isinstance(value, torch.Tensor):
334334
flatten_inputs[name] = value
335335
elif isinstance(value, (list, tuple)):
336-
for i, tensor in enumerate(value):
337-
name_i = f'{name}_{i}'
338-
flatten_inputs[name_i] = tensor
336+
if len(value) == 1:
337+
flatten_inputs[name] = value[0]
338+
else:
339+
for i, tensor in enumerate(value):
340+
name_i = f'{name}_{i}'
341+
flatten_inputs[name_i] = tensor
339342
return flatten_inputs
340343

341344

@@ -358,15 +361,29 @@ def get_onnx_model(wrapped_model: nn.Module,
358361
patched_model = patch_model(
359362
wrapped_model, cfg=deploy_cfg, backend=backend.value)
360363
flatten_model_inputs = get_flatten_inputs(model_inputs)
361-
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
364+
input_names = onnx_cfg.get('input_names', None)
365+
if input_names is None:
366+
input_names = [
367+
k for k, v in flatten_model_inputs.items() if k != 'ctx'
368+
]
362369
output_names = onnx_cfg.get('output_names', None)
363370
dynamic_axes = get_dynamic_axes(deploy_cfg, input_names)
364371

372+
class DummyModel(torch.nn.Module):
373+
374+
def __init__(self):
375+
super(DummyModel, self).__init__()
376+
self.model = patched_model
377+
378+
def forward(self, inputs: dict):
379+
return self.model(**inputs)
380+
381+
model = DummyModel().eval()
382+
365383
with RewriterContext(
366384
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
367385
torch.onnx.export(
368-
patched_model,
369-
tuple([v for k, v in model_inputs.items()]),
386+
model, (model_inputs, {}),
370387
onnx_file_path,
371388
export_params=True,
372389
input_names=input_names,
@@ -421,8 +438,13 @@ def get_backend_outputs(ir_file_path: str,
421438
"""
422439
backend = get_backend(deploy_cfg)
423440
flatten_model_inputs = get_flatten_inputs(model_inputs)
424-
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
425-
output_names = get_ir_config(deploy_cfg).get('output_names', None)
441+
ir_config = get_ir_config(deploy_cfg)
442+
input_names = ir_config.get('input_names', None)
443+
output_names = ir_config.get('output_names', None)
444+
if input_names is None:
445+
input_names = [
446+
k for k, v in flatten_model_inputs.items() if k != 'ctx'
447+
]
426448

427449
# prepare backend model and input features
428450
if backend == Backend.TENSORRT:

tests/test_codebase/test_mmseg/test_mmseg_models.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
1010

1111
from mmdeploy.codebase import import_codebase
12-
from mmdeploy.utils import Backend, Codebase
12+
from mmdeploy.utils import Backend, Codebase, Task
1313
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
1414
get_rewrite_outputs)
1515

@@ -261,3 +261,56 @@ def test_emamodule_forward(backend):
261261
model_outputs.shape)
262262
assert torch.allclose(
263263
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
264+
265+
266+
@pytest.mark.parametrize('is_dynamic_shape', [True, False])
267+
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
268+
def test_upconvblock_forward(backend, is_dynamic_shape):
269+
check_backend(backend)
270+
from mmseg.models.backbones.unet import BasicConvBlock
271+
from mmseg.models.utils import UpConvBlock
272+
273+
head = UpConvBlock(BasicConvBlock, 16, 8, 8).eval()
274+
dynamic_axes = {
275+
'x': {
276+
0: 'b',
277+
2: 'h',
278+
3: 'w'
279+
},
280+
'skip': {
281+
0: 'b',
282+
2: 'h',
283+
3: 'w'
284+
},
285+
'output': {
286+
0: 'b',
287+
2: 'h',
288+
3: 'w'
289+
},
290+
} if is_dynamic_shape else None
291+
deploy_cfg = mmcv.Config(
292+
dict(
293+
backend_config=dict(type=backend.value),
294+
onnx_config=dict(
295+
input_names=['skip', 'x'],
296+
output_names=['output'],
297+
dynamic_axes=dynamic_axes),
298+
codebase_config=dict(
299+
type=Codebase.MMSEG.value, task=Task.SEGMENTATION.value)))
300+
x = torch.randn(1, 16, 16, 16)
301+
skip = torch.randn(1, 8, 32, 32)
302+
model_inputs = {'x': x, 'skip': skip}
303+
with torch.no_grad():
304+
model_outputs = get_model_outputs(head, 'forward', model_inputs)
305+
306+
wrapped_model = WrapModel(head, 'forward')
307+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
308+
wrapped_model=wrapped_model,
309+
model_inputs=model_inputs,
310+
deploy_cfg=deploy_cfg)
311+
if is_backend_output:
312+
rewrite_outputs = rewrite_outputs[0]
313+
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
314+
model_outputs.shape)
315+
assert torch.allclose(
316+
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)

0 commit comments

Comments
 (0)