@@ -879,3 +879,209 @@ def _create_object_prediction_list_from_original_predictions(
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image


@check_requirements(["torch", "torchvision"])
class TorchVisionDetectionModel(DetectionModel):
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: int = None,
    ):
        super().__init__(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
        )

    def load_model(self):
        import torch

        from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR

        # read config params
        model_name = None
        num_classes = None
        if self.config_path is not None:
            import yaml

            with open(self.config_path, "r") as stream:
                try:
                    config = yaml.safe_load(stream)
                except yaml.YAMLError as exc:
                    raise RuntimeError(exc)

            model_name = config.get("model_name", None)
            num_classes = config.get("num_classes", None)
        # complete params if not provided in config
        if not model_name:
            model_name = "fasterrcnn_resnet50_fpn"
            logger.warning(f"model_name not provided in config, using default model_name: {model_name}")
        if num_classes is None:
            logger.warning("num_classes not provided in config, using default num_classes: 91")
            num_classes = 91
        if self.model_path is None:
            logger.warning("model_path not provided, using pretrained weights and default num_classes: 91.")
            pretrained = True
            num_classes = 91
        else:
            pretrained = False

        # load model
        model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, pretrained=pretrained)
        # only load a state dict when a checkpoint path is given (pretrained weights are used otherwise)
        if self.model_path is not None:
            try:
                model.load_state_dict(torch.load(self.model_path))
            except Exception as e:
                raise TypeError(f"model_path is not a valid torchvision model path: {e}")

        self.set_model(model)
    def set_model(self, model: Any):
        """
        Sets the underlying TorchVision model.
        Args:
            model: Any
                A TorchVision model
        """
        model.eval()
        self.model = model.to(self.device)

        # set category_mapping
        from sahi.utils.torchvision import COCO_CLASSES

        if self.category_mapping is None:
            category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
            self.category_mapping = category_names
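            # the resulting mapping looks like {"0": "__background__", "1": "person", ...}
            # (assuming COCO_CLASSES follows torchvision's COCO category ordering)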

    def perform_inference(self, image: np.ndarray, image_size: int = None):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            image_size: int
                Inference input size.
        """
        from sahi.utils.torch import to_float_tensor

        # arrange model input size
        if self.image_size is not None:
            # get min and max of image height and width
            min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
            # torchvision resize transform scales the shorter dimension to the target size;
            # we want to scale the longer dimension to the target size instead
            image_size = self.image_size * min_shape / max_shape
            self.model.transform.min_size = (image_size,)  # default is (800,)
            self.model.transform.max_size = image_size  # default is 1333
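            # e.g. for an 800x1333 image with self.image_size=640:
            # image_size = 640 * 800 / 1333 ≈ 384, so torchvision scales the
            # shorter side to ~384 and the longer side lands at ~640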

        image = to_float_tensor(image)
        image = image.to(self.device)
        prediction_result = self.model([image])

        self._original_predictions = prediction_result
    @property
    def num_categories(self):
        """
        Returns number of categories
        """
        return len(self.category_mapping)

    @property
    def has_mask(self):
        """
        Returns if model output contains segmentation mask
        """
        return self.model.with_mask

    @property
    def category_names(self):
        return list(self.category_mapping.values())
    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.20
        if isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]

        # accumulate per-image predictions (initialized once, outside the image loop)
        object_prediction_list_per_image = []
        for image_ind, image_predictions in enumerate(original_predictions):
            object_prediction_list = []

            # get indices of boxes with score > confidence_threshold
            scores = image_predictions["scores"].cpu().detach().numpy()
            selected_indices = np.where(scores > self.confidence_threshold)[0]

            # parse boxes, masks, scores, category_ids from predictions
            category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy())
            boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy())
            scores = scores[selected_indices]

            # check if predictions contain mask
            masks = image_predictions.get("masks", None)
            if masks is not None:
                masks = list(image_predictions["masks"][selected_indices].cpu().detach().numpy())

            # create object_prediction_list
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]

            for ind in range(len(boxes)):
                mask = np.array(masks[ind]) if masks is not None else None

                object_prediction = ObjectPrediction(
                    bbox=boxes[ind],
                    bool_mask=mask,
                    category_id=int(category_ids[ind]),
                    category_name=self.category_mapping[str(int(category_ids[ind]))],
                    shift_amount=shift_amount,
                    score=scores[ind],
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
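
A minimal usage sketch for the new class (hypothetical paths and thresholds; `get_sliced_prediction` and `export_visuals` are existing SAHI entry points, and the constructor arguments follow the `__init__` defined above):

from sahi.model import TorchVisionDetectionModel
from sahi.predict import get_sliced_prediction

# model_path=None falls back to pretrained COCO weights (see load_model above)
detection_model = TorchVisionDetectionModel(
    confidence_threshold=0.5,
    image_size=640,
    device="cpu",  # or "cuda:0"
)

# "demo.jpg" is a placeholder image path
result = get_sliced_prediction(
    "demo.jpg",
    detection_model,
    slice_height=512,
    slice_width=512,
)
result.export_visuals(export_dir="outputs/")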