Skip to content

Commit 28f86de

Browse files
authored
[Unittest]: Test for CascadeRoIHead (open-mmlab#141)
* Fix include and lib paths for onnxruntime. * Fixes for SSD export test * Add onnx2openvino and OpenVINODetector. Test models: ssd, retinanet, fcos, fsaf. * Add support for two-stage models: faster_rcnn, cascade_rcnn * Add doc * Add strip_doc_string for openvino. * Fix openvino preprocess. * Add OpenVINO to test_wrapper.py. * Fix * Add openvino_execute. * Removed preprocessing. * Fix onnxruntime cmake. * Rewrote postprocessing and forward, added docstrings and fixes. * Added device type change to OpenVINOWrapper. * Update forward_of_single_roi_extractor_dynamic_openvino and fix doc. * Update docs. * Add OpenVINODetector and onn2openvino tests. * Add input_info to onnx2openvino. * Add TestOpenVINOExporter and test_single_roi_extractor. * Moved get_input_shape_from_cfg to openvino_utils.py and added test. * Added test_cascade_roi_head. * Add backend.check_env() to tests. * Add OpenVINO to get_rewrite_outputs and to some tests in test_mmdet_models. * Moved test_single_roi_extractor to test_mmdet_models. * Removed TestOpenVINOExporter. * Added test_cascade_roi_head. * Fix onnxruntime outputs type.
1 parent 8987fa8 commit 28f86de

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

tests/test_mmdet/test_mmdet_models.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,134 @@ def test_single_roi_extractor(backend_type):
366366
backend_output = backend_output.squeeze()
367367
assert np.allclose(
368368
model_output, backend_output, rtol=1e-03, atol=1e-05)
369+
370+
371+
def get_cascade_roi_head():
372+
"""CascadeRoIHead Config."""
373+
num_stages = 3
374+
stage_loss_weights = [1, 0.5, 0.25]
375+
bbox_roi_extractor = {
376+
'type': 'SingleRoIExtractor',
377+
'roi_layer': {
378+
'type': 'RoIAlign',
379+
'output_size': 7,
380+
'sampling_ratio': 0
381+
},
382+
'out_channels': 64,
383+
'featmap_strides': [4, 8, 16, 32]
384+
}
385+
all_target_stds = [[0.1, 0.1, 0.2, 0.2], [0.05, 0.05, 0.1, 0.1],
386+
[0.033, 0.033, 0.067, 0.067]]
387+
bbox_head = [{
388+
'type': 'Shared2FCBBoxHead',
389+
'in_channels': 64,
390+
'fc_out_channels': 1024,
391+
'roi_feat_size': 7,
392+
'num_classes': 80,
393+
'bbox_coder': {
394+
'type': 'DeltaXYWHBBoxCoder',
395+
'target_means': [0.0, 0.0, 0.0, 0.0],
396+
'target_stds': target_stds
397+
},
398+
'reg_class_agnostic': True,
399+
'loss_cls': {
400+
'type': 'CrossEntropyLoss',
401+
'use_sigmoid': False,
402+
'loss_weight': 1.0
403+
},
404+
'loss_bbox': {
405+
'type': 'SmoothL1Loss',
406+
'beta': 1.0,
407+
'loss_weight': 1.0
408+
}
409+
} for target_stds in all_target_stds]
410+
411+
test_cfg = mmcv.Config(
412+
dict(
413+
score_thr=0.05,
414+
nms=mmcv.Config(dict(type='nms', iou_threshold=0.5)),
415+
max_per_img=100))
416+
417+
from mmdet.models import CascadeRoIHead
418+
model = CascadeRoIHead(
419+
num_stages,
420+
stage_loss_weights,
421+
bbox_roi_extractor,
422+
bbox_head,
423+
test_cfg=test_cfg).eval()
424+
return model
425+
426+
427+
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'openvino'])
428+
def test_cascade_roi_head(backend_type):
429+
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
430+
431+
cascade_roi_head = get_cascade_roi_head()
432+
seed_everything(1234)
433+
x = [
434+
torch.rand((1, 64, 200, 304)),
435+
torch.rand((1, 64, 100, 152)),
436+
torch.rand((1, 64, 50, 76)),
437+
torch.rand((1, 64, 25, 38)),
438+
]
439+
proposals = torch.tensor([[587.8285, 52.1405, 886.2484, 341.5644, 0.5]])
440+
img_metas = mmcv.Config({
441+
'img_shape': torch.tensor([800, 1216]),
442+
'ori_shape': torch.tensor([800, 1216]),
443+
'scale_factor': torch.tensor([1, 1, 1, 1])
444+
})
445+
446+
model_inputs = {
447+
'x': x,
448+
'proposal_list': [proposals],
449+
'img_metas': [img_metas]
450+
}
451+
model_outputs = get_model_outputs(cascade_roi_head, 'simple_test',
452+
model_inputs)
453+
processed_model_outputs = []
454+
for output in model_outputs[0]:
455+
if output.shape == (0, 5):
456+
processed_model_outputs.append(np.zeros((1, 5)))
457+
else:
458+
processed_model_outputs.append(output)
459+
processed_model_outputs = np.array(processed_model_outputs).squeeze()
460+
processed_model_outputs = processed_model_outputs[None, :, :]
461+
462+
output_names = ['results']
463+
deploy_cfg = mmcv.Config(
464+
dict(
465+
backend_config=dict(type=backend_type),
466+
onnx_config=dict(output_names=output_names, input_shape=None),
467+
codebase_config=dict(
468+
type='mmdet',
469+
task='ObjectDetection',
470+
post_processing=dict(
471+
score_threshold=0.05,
472+
iou_threshold=0.5,
473+
max_output_boxes_per_class=200,
474+
pre_top_k=-1,
475+
keep_top_k=100,
476+
background_label_id=-1))))
477+
model_inputs = {'x': x, 'proposals': proposals.unsqueeze(0)}
478+
wrapped_model = WrapModel(
479+
cascade_roi_head, 'simple_test', img_metas=img_metas)
480+
backend_outputs, _ = get_rewrite_outputs(
481+
wrapped_model=wrapped_model,
482+
model_inputs=model_inputs,
483+
deploy_cfg=deploy_cfg)
484+
processed_backend_outputs = []
485+
if isinstance(backend_outputs, dict):
486+
processed_backend_outputs = [
487+
backend_outputs[name] for name in output_names
488+
if name in backend_outputs
489+
]
490+
elif isinstance(backend_outputs, (list, tuple)) and \
491+
backend_outputs[0].shape == (1, 0, 5):
492+
processed_backend_outputs = np.zeros((1, 80, 5))
493+
else:
494+
processed_backend_outputs = backend_outputs
495+
assert np.allclose(
496+
processed_model_outputs,
497+
processed_backend_outputs,
498+
rtol=1e-03,
499+
atol=1e-05)

0 commit comments

Comments
 (0)