Commit f172ba9
add torchvision detector support (#486)
* add torchvision detector support
* add test config files for torchvision models
* update notebook
* add TorchVisionDetectionModel
* add torchvision to automodel
* add torchvision demo url, update versions in readme
* fix linting
1 parent 268da8d commit f172ba9

File tree

8 files changed: +604 −4 lines changed

README.md

Lines changed: 5 additions & 3 deletions
````diff
@@ -75,6 +75,8 @@ Object detection and instance segmentation are by far the most important fields
 
 - `HuggingFace` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_huggingface.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-huggingface"></a> (NEW)
 
+- `TorchVision` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_torchvision.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-torchvision"></a> (NEW)
+
 <a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img width="600" src="https://user-images.githubusercontent.com/34196005/144092739-c1d9bade-a128-4346-947f-424ce00e5c4f.gif" alt="sahi-yolox"></a>
 
@@ -111,17 +113,17 @@ conda install pytorch=1.10.2 torchvision=0.11.3 cudatoolkit=11.3 -c pytorch
 - Install your desired detection framework (yolov5):
 
 ```console
-pip install yolov5
+pip install yolov5==6.1.3
 ```
 
 - Install your desired detection framework (mmdet):
 
 ```console
-pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
+pip install mmcv-full==1.5.3 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
 ```
 
 ```console
-pip install mmdet==2.21.0
+pip install mmdet==2.25.0
 ```
 
 - Install your desired detection framework (detectron2):
````

demo/inference_for_torchvision.ipynb

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

sahi/auto_model.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,12 +1,19 @@
 from typing import Dict, Optional
 
-from sahi.model import Detectron2DetectionModel, HuggingfaceDetectionModel, MmdetDetectionModel, Yolov5DetectionModel
+from sahi.model import (
+    Detectron2DetectionModel,
+    HuggingfaceDetectionModel,
+    MmdetDetectionModel,
+    TorchVisionDetectionModel,
+    Yolov5DetectionModel,
+)
 
 MODEL_TYPE_TO_MODEL_CLASS_NAME = {
     "mmdet": MmdetDetectionModel,
     "yolov5": Yolov5DetectionModel,
     "detectron2": Detectron2DetectionModel,
     "huggingface": HuggingfaceDetectionModel,
+    "torchvision": TorchVisionDetectionModel,
 }
```
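With this mapping in place, the `"torchvision"` key becomes usable through SAHI's auto-model factory. A minimal sketch, assuming the `AutoDetectionModel.from_pretrained` factory method exposed by `sahi/auto_model.py` (method name as in recent sahi releases; the config path is one of the test configs added by this commit):

```python
from sahi.auto_model import AutoDetectionModel

# "torchvision" resolves to TorchVisionDetectionModel via
# MODEL_TYPE_TO_MODEL_CLASS_NAME; without a model_path, load_model()
# falls back to pretrained COCO weights (see sahi/model.py below)
detection_model = AutoDetectionModel.from_pretrained(
    model_type="torchvision",
    config_path="tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml",
    confidence_threshold=0.4,
    device="cpu",
)
```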

sahi/model.py

Lines changed: 206 additions & 0 deletions
```diff
@@ -879,3 +879,209 @@ def _create_object_prediction_list_from_original_predictions(
             object_prediction_list_per_image.append(object_prediction_list)
 
         self._object_prediction_list_per_image = object_prediction_list_per_image
```

The new class, appended after the existing detection models:

```python
@check_requirements(["torch", "torchvision"])
class TorchVisionDetectionModel(DetectionModel):
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: int = None,
    ):
        super().__init__(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
        )

    def load_model(self):
        import torch

        from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR

        # read config params
        model_name = None
        num_classes = None
        if self.config_path is not None:
            import yaml

            with open(self.config_path, "r") as stream:
                try:
                    config = yaml.safe_load(stream)
                except yaml.YAMLError as exc:
                    raise RuntimeError(exc)

            model_name = config.get("model_name", None)
            num_classes = config.get("num_classes", None)

        # complete params if not provided in config
        if not model_name:
            model_name = "fasterrcnn_resnet50_fpn"
            logger.warning(f"model_name not provided in config, using default model_name: {model_name}")
        if num_classes is None:
            logger.warning("num_classes not provided in config, using default num_classes: 91")
            num_classes = 91
        if self.model_path is None:
            logger.warning("model_path not provided, using pretrained weights and default num_classes: 91.")
            pretrained = True
            num_classes = 91
        else:
            pretrained = False

        # load model
        model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, pretrained=pretrained)
        if self.model_path is not None:
            try:
                model.load_state_dict(torch.load(self.model_path))
            except Exception as e:
                raise TypeError(f"model_path is not a valid torchvision model path: {e}")

        self.set_model(model)

    def set_model(self, model: Any):
        """
        Sets the underlying TorchVision model.
        Args:
            model: Any
                A TorchVision model
        """
        model.eval()
        self.model = model.to(self.device)

        # set category_mapping
        from sahi.utils.torchvision import COCO_CLASSES

        if self.category_mapping is None:
            category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
            self.category_mapping = category_names

    def perform_inference(self, image: np.ndarray, image_size: int = None):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            image_size: int
                Inference input size.
        """
        from sahi.utils.torch import to_float_tensor

        # arrange model input size
        if self.image_size is not None:
            # get min and max of image height and width
            min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
            # torchvision resize transform scales the shorter dimension to the target size;
            # we want to scale the longer dimension to the target size
            image_size = self.image_size * min_shape / max_shape
            self.model.transform.min_size = (image_size,)  # default is (800,)
            self.model.transform.max_size = image_size  # default is 1333

        image = to_float_tensor(image)
        image = image.to(self.device)
        prediction_result = self.model([image])

        self._original_predictions = prediction_result

    @property
    def num_categories(self):
        """
        Returns number of categories
        """
        return len(self.category_mapping)

    @property
    def has_mask(self):
        """
        Returns if model output contains segmentation mask
        """
        # torchvision detectors expose no `with_mask` attribute; only Mask R-CNN
        # variants carry a mask predictor on their roi_heads
        return hasattr(self.model, "roi_heads") and self.model.roi_heads.mask_predictor is not None

    @property
    def category_names(self):
        return list(self.category_mapping.values())

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.20
        if isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]

        object_prediction_list_per_image = []
        for image_predictions in original_predictions:
            # get indices of boxes with score > confidence_threshold
            scores = image_predictions["scores"].cpu().detach().numpy()
            selected_indices = np.where(scores > self.confidence_threshold)[0]

            # parse boxes, masks, scores, category_ids from predictions
            category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy())
            boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy())
            scores = scores[selected_indices]

            # check if predictions contain mask
            masks = image_predictions.get("masks", None)
            if masks is not None:
                masks = list(image_predictions["masks"][selected_indices].cpu().detach().numpy())

            # create object_prediction_list
            object_prediction_list = []

            shift_amount = shift_amount_list[0]
            full_shape = None if full_shape_list is None else full_shape_list[0]

            for ind in range(len(boxes)):
                mask = np.array(masks[ind]) if masks is not None else None

                object_prediction = ObjectPrediction(
                    bbox=boxes[ind],
                    bool_mask=mask,
                    category_id=int(category_ids[ind]),
                    category_name=self.category_mapping[str(int(category_ids[ind]))],
                    shift_amount=shift_amount,
                    score=scores[ind],
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
```
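For context, a hedged end-to-end sketch of how the new class plugs into SAHI's existing sliced-inference API (`get_sliced_prediction` from `sahi.predict`); the sample image path points at the demo data shipped in the repo, and the threshold and slice values are illustrative:

```python
import torchvision

from sahi.model import TorchVisionDetectionModel
from sahi.predict import get_sliced_prediction

# wrap a stock torchvision detector; set_model() switches it to eval
# mode and builds the default COCO category_mapping
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
detection_model = TorchVisionDetectionModel(
    model=model,
    confidence_threshold=0.5,
    device="cpu",
)

# slice the image into 512x512 crops with 20% overlap, run the detector
# on each crop, then merge the shifted predictions back to full-image coords
result = get_sliced_prediction(
    "demo/demo_data/small-vehicles1.jpeg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
print(len(result.object_prediction_list))
```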

sahi/utils/torchvision.py

Lines changed: 124 additions & 0 deletions
```python
# OBSS SAHI Tool
# Code written by Kadir Nar, 2022.

from packaging import version

from sahi.utils.import_utils import _torchvision_available, _torchvision_version, is_available


class TorchVisionTestConstants:
    FASTERRCNN_CONFIG_PATH = "tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml"
    SSD300_CONFIG_PATH = "tests/data/models/torchvision/ssd300_vgg16.yaml"


if _torchvision_available:
    import torchvision

    MODEL_NAME_TO_CONSTRUCTOR = {
        "fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn,
        "fasterrcnn_mobilenet_v3_large_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn,
        "fasterrcnn_mobilenet_v3_large_320_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn,
        "retinanet_resnet50_fpn": torchvision.models.detection.retinanet_resnet50_fpn,
        "ssd300_vgg16": torchvision.models.detection.ssd300_vgg16,
        "ssdlite320_mobilenet_v3_large": torchvision.models.detection.ssdlite320_mobilenet_v3_large,
    }

    # fcos requires torchvision >= 0.12.0
    if version.parse(_torchvision_version) >= version.parse("0.12.0"):
        MODEL_NAME_TO_CONSTRUCTOR["fcos_resnet50_fpn"] = torchvision.models.detection.fcos_resnet50_fpn


COCO_CLASSES = [
    "__background__",
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "N/A",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "N/A",
    "backpack",
    "umbrella",
    "N/A",
    "N/A",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "N/A",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "N/A",
    "dining table",
    "N/A",
    "N/A",
    "toilet",
    "N/A",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "N/A",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]
```
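The mapping above is what `load_model()` consults: a YAML `model_name` indexes into `MODEL_NAME_TO_CONSTRUCTOR` and the matching torchvision factory is called. A small sketch of that lookup in isolation (the `num_classes` value is illustrative):

```python
from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR

# same lookup TorchVisionDetectionModel.load_model performs:
# model_name -> torchvision constructor -> model instance
constructor = MODEL_NAME_TO_CONSTRUCTOR["ssd300_vgg16"]
model = constructor(num_classes=91, pretrained=True)
model.eval()
```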
tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml

Lines changed: 2 additions & 0 deletions

```yaml
model_name: fasterrcnn_resnet50_fpn
num_classes: 91
```
tests/data/models/torchvision/ssd300_vgg16.yaml

Lines changed: 2 additions & 0 deletions

```yaml
model_name: ssd300_vgg16
num_classes: 91
```
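These two test configs exercise the config-file path of `load_model()`: `model_name` picks the constructor from `MODEL_NAME_TO_CONSTRUCTOR` and `num_classes` sizes the detection head. A hedged sketch of how they might be consumed (the weights path is hypothetical; with `model_path=None` the model would instead load pretrained COCO weights):

```python
from sahi.model import TorchVisionDetectionModel
from sahi.utils.torchvision import TorchVisionTestConstants

detection_model = TorchVisionDetectionModel(
    model_path="path/to/fasterrcnn_weights.pt",  # hypothetical local weights
    config_path=TorchVisionTestConstants.FASTERRCNN_CONFIG_PATH,
    confidence_threshold=0.5,
    device="cpu",
)
```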
