Commit 8045480

Add a new object detector, Lite-DINO (#2457)
* Initial implementation of Lite DETR
* Update model config for lite dino
* Add norm to intermediate layer of ffn
* Change FFN's norm order and add enc_scale attribute to encoder's layers
* Merge with incremental recipe
* Add model pretrained weight path
* Update model info and add intg tests
* Update docs
* Update CHANGELOG
* Change num iters
1 parent fc6386c commit 8045480

13 files changed, +648 -9 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
 - Add ONNX metadata to detection, instance segmentation, and segmentation models (<https://github.com/openvinotoolkit/training_extensions/pull/2418>)
 - Add a new feature to configure input size (<https://github.com/openvinotoolkit/training_extensions/pull/2420>)
 - Introduce the OTXSampler and AdaptiveRepeatDataHook to achieve faster training at the small data regime (<https://github.com/openvinotoolkit/training_extensions/pull/2428>)
+- Add a new object detector, Lite-DINO (<https://github.com/openvinotoolkit/training_extensions/pull/2457>)

 ### Enhancements

docs/source/guide/explanation/algorithms/object_detection/object_detection.rst

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,8 @@ In addition to these models, we supports experimental models for object detectio
 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+
 | `Custom_Object_Detection_Gen3_DINO <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/algorithms/detection/configs/detection/resnet50_dino/template_experimental.yaml>`_ | DINO | 235 | 182.0 |
 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+
+| `Custom_Object_Detection_Gen3_Lite_DINO <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/algorithms/detection/configs/detection/resnet50_litedino/template_experimental.yaml>`_ | Lite-DINO | 140 | 190.0 |
++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+
 | `Custom_Object_Detection_Gen3_ResNeXt101_ATSS <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/algorithms/detection/configs/detection/resnext101_atss/template_experimental.yaml>`_ | ResNeXt101-ATSS | 434.75 | 344.0 |
 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+
 | `Object_Detection_YOLOX_S <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_s/template_experimental.yaml>`_ | YOLOX_S | 33.51 | 46.0 |
@@ -110,6 +112,7 @@ In addition to these models, we supports experimental models for object detectio
 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+---------------------+-----------------+

 `Deformable_DETR <https://arxiv.org/abs/2010.04159>`_ is a model based on `DETR <https://arxiv.org/abs/2005.12872>`_ that addresses DETR's slow convergence. `DINO <https://arxiv.org/abs/2203.03605>`_ improves Deformable DETR based methods by denoising anchor boxes. Current SOTA models for object detection are based on DINO.
+`Lite-DINO <https://arxiv.org/abs/2303.07335>`_ is an efficient structure for DINO. It reduces the FLOPs of the transformer encoder, which accounts for the largest share of the computational cost.
 Although transformer-based models show notable performance on various object detection benchmarks, CNN-based models still show good performance with reasonable latency.
 Therefore, we added a new experimental CNN-based method, ResNeXt101-ATSS. ATSS still shows good performance among `RetinaNet <https://arxiv.org/abs/1708.02002>`_ based models. We integrated a large ResNeXt101 backbone into our Custom ATSS head, and it shows good transfer learning performance.
 In addition, we added YOLOX variants to support users' diverse situations.
@@ -154,6 +157,8 @@ We trained each model with a single Nvidia GeForce RTX3090.
 +----------------------------+------------------+-----------+-----------+-----------+-----------+--------------+
 | ResNet50-DINO | 49.0 (66.4) | 47.2 | 99.5 | 62.9 | 93.5 | 99.1 |
 +----------------------------+------------------+-----------+-----------+-----------+-----------+--------------+
+| ResNet50-Lite-DINO | 48.1 (64.4) | 47.0 | 99.0 | 62.5 | 93.6 | 99.4 |
++----------------------------+------------------+-----------+-----------+-----------+-----------+--------------+
 | YOLOX_S | 40.3 (59.1) | 37.1 | 93.6 | 54.8 | 92.7 | 98.8 |
 +----------------------------+------------------+-----------+-----------+-----------+-----------+--------------+
 | YOLOX_L | 49.4 (67.1) | 44.5 | 94.6 | 55.8 | 91.8 | 99.0 |

src/otx/algorithms/common/adapters/mmcv/ops/multi_scale_deformable_attn_pytorch.py

Lines changed: 9 additions & 8 deletions
@@ -78,6 +78,7 @@ def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: boo
     Returns:
         torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
     """
+    device = im.device
     n, c, h, w = im.shape
     gn, gh, gw, _ = grid.shape
     assert n == gn
@@ -113,14 +114,14 @@ def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: boo
     x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

     # Clip coordinates to padded image size
-    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
-    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
-    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
-    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
-    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
-    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
-    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
-    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+    x0 = torch.where(x0 < 0, torch.tensor(0).to(device), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).to(device), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0).to(device), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).to(device), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0).to(device), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).to(device), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0).to(device), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).to(device), y1)

     im_padded = im_padded.view(n, c, -1)

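Note on the hunk above: the change builds the scalar fill values on the same device as the input before passing them to `torch.where`. The following is a minimal sketch of that pattern, not code from the commit. The helper name `clamp_indices` and the sample values are assumptions made for illustration only.

```python
import torch


def clamp_indices(idx: torch.Tensor, upper: int) -> torch.Tensor:
    """Clip integer sample indices to the closed range [0, upper].

    Hypothetical helper for illustration; it is not part of the commit.
    """
    device = idx.device
    # Create the fill scalars on the same device (and dtype) as the indices,
    # mirroring the `.to(device)` calls added in the diff.
    low = torch.tensor(0, dtype=idx.dtype, device=device)
    high = torch.tensor(upper, dtype=idx.dtype, device=device)
    idx = torch.where(idx < 0, low, idx)
    idx = torch.where(idx > upper, high, idx)
    return idx


x0 = torch.tensor([-2, 0, 3, 9])
clipped = clamp_indices(x0, upper=5)

# torch.clamp expresses the same clipping in a single call.
assert torch.equal(clipped, torch.clamp(x0, min=0, max=5))
```

The commit itself keeps the `torch.where` form; the comparison with `torch.clamp` is only there to check the sketch.
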
src/otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 from .custom_atss_detector import CustomATSS
 from .custom_deformable_detr_detector import CustomDeformableDETR
 from .custom_dino_detector import CustomDINO
+from .custom_lite_dino import CustomLiteDINO
 from .custom_maskrcnn_detector import CustomMaskRCNN
 from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
 from .custom_single_stage_detector import CustomSingleStageDetector
@@ -19,6 +20,7 @@
 __all__ = [
     "CustomATSS",
     "CustomDeformableDETR",
+    "CustomLiteDINO",
     "CustomDINO",
     "CustomMaskRCNN",
     "CustomSingleStageDetector",
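
Note on the import above: a registry decorator such as `@DETECTORS.register_module()` typically runs when the defining module is imported, so the import in `__init__.py` is what makes the class available by name to config-driven builders. The sketch below illustrates that mechanism with a toy registry written for this note. `Registry`, `ToyLiteDINO`, and the `build` helper are hypothetical stand-ins, not the mmdet or mmcv implementation.

```python
class Registry:
    """Toy name-to-class registry (hypothetical stand-in, not mmcv's Registry)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Returns a decorator that records the class under its own name.
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls

        return decorator

    def get(self, key):
        return self._modules.get(key)

    def build(self, cfg):
        # Copy so the caller's config dict is not mutated by pop().
        cfg = dict(cfg)
        cls = self._modules[cfg.pop("type")]
        return cls(**cfg)


DETECTORS = Registry("detector")


@DETECTORS.register_module()
class ToyLiteDINO:
    """Placeholder detector used only to show registration."""

    def __init__(self, num_classes=80):
        self.num_classes = num_classes


# A config dict names the class by string; the registry resolves it.
model = DETECTORS.build({"type": "ToyLiteDINO", "num_classes": 3})
```

If the module defining the class were never imported, the lookup by name would fail in this toy version, which is the reason for exposing the new detector from the package `__init__.py`.
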
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+"""OTX Lite-DINO Class for object detection."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from mmdet.models.builder import DETECTORS
+
+from otx.algorithms.common.utils.logger import get_logger
+from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomDINO
+
+logger = get_logger()
+
+
+@DETECTORS.register_module()
+class CustomLiteDINO(CustomDINO):
+    """Custom Lite-DINO <https://arxiv.org/pdf/2303.07335.pdf> for object detection."""
+
+    def load_state_dict_pre_hook(self, model_classes, ckpt_classes, ckpt_dict, *args, **kwargs):
+        """Modify official lite dino version's weights before weight loading."""
+        super(CustomDINO, self).load_state_dict_pre_hook(model_classes, ckpt_classes, ckpt_dict, *args, **kwargs)
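
Note on the hook above: `super(CustomDINO, self)` starts the method lookup after `CustomDINO` in the MRO, so if `CustomDINO` defines its own `load_state_dict_pre_hook`, that override is bypassed and the next class's hook runs. Forwarding keyword arguments needs `**kwargs`; a single `*` would pass only the dict keys as positional arguments. The sketch below shows both behaviours with toy classes. `Base`, `Parent`, `Child`, and `show` are hypothetical names, and the hook bodies are placeholders rather than the OTX logic.

```python
class Base:
    def load_state_dict_pre_hook(self, model_classes, ckpt_classes, ckpt_dict, *args, **kwargs):
        ckpt_dict["seen_by"].append("Base")
        ckpt_dict["kwargs"] = kwargs


class Parent(Base):
    def load_state_dict_pre_hook(self, model_classes, ckpt_classes, ckpt_dict, *args, **kwargs):
        ckpt_dict["seen_by"].append("Parent")
        super().load_state_dict_pre_hook(model_classes, ckpt_classes, ckpt_dict, *args, **kwargs)


class Child(Parent):
    def load_state_dict_pre_hook(self, model_classes, ckpt_classes, ckpt_dict, *args, **kwargs):
        # super(Parent, self) begins the lookup after Parent in the MRO,
        # so Parent's override is skipped and Base's hook runs directly.
        super(Parent, self).load_state_dict_pre_hook(model_classes, ckpt_classes, ckpt_dict, *args, **kwargs)


def show(*args, **kwargs):
    """Return the positional and keyword arguments exactly as received."""
    return args, kwargs


ckpt = {"seen_by": []}
Child().load_state_dict_pre_hook(["a"], ["a"], ckpt, strict=True)

opts = {"strict": True}
single_star = show(*opts)   # keys become positional arguments
double_star = show(**opts)  # items become keyword arguments
```
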

src/otx/algorithms/detection/adapters/mmdet/models/layers/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -5,5 +5,13 @@

 from .dino import CustomDINOTransformer
 from .dino_layers import CdnQueryGenerator, DINOTransformerDecoder
+from .lite_detr_layers import EfficientTransformerEncoder, EfficientTransformerLayer, SmallExpandFFN

-__all__ = ["CustomDINOTransformer", "DINOTransformerDecoder", "CdnQueryGenerator"]
+__all__ = [
+    "CustomDINOTransformer",
+    "DINOTransformerDecoder",
+    "CdnQueryGenerator",
+    "EfficientTransformerEncoder",
+    "EfficientTransformerLayer",
+    "SmallExpandFFN",
+]
