@@ -1,4 +1,4 @@
-# YOLOv5 by Ultralytics, GPL-3.0 license
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
 """
 Common modules
 """
@@ -16,19 +16,17 @@
 from PIL import Image
 from torch import nn, Tensor
 from torch.cuda import amp
-from yolort.v5.utils.datasets import exif_transpose, letterbox
 from yolort.v5.utils.general import (
     colorstr,
     increment_path,
     is_ascii,
     make_divisible,
     non_max_suppression,
-    save_one_box,
     scale_coords,
     xyxy2xywh,
 )
-from yolort.v5.utils.plots import Annotator, colors
-from yolort.v5.utils.torch_utils import time_sync
+from yolort.v5.utils.plots import Annotator, colors, save_one_box
+from yolort.v5.utils.torch_utils import copy_attr, time_sync
 
 LOGGER = logging.getLogger(__name__)
 
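The import hunk moves `save_one_box` over to `yolort.v5.utils.plots` and pulls in `copy_attr`, which the new `__init__` below uses to mirror selected attributes from the wrapped model. For orientation, the upstream YOLOv5 helper is roughly the following sketch (from upstream `utils/torch_utils.py`, not part of this diff):

```python
def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, optionally restricted to `include`,
    # always skipping private names and anything listed in `exclude`.
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        setattr(a, k, v)
```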
@@ -414,32 +412,52 @@ def forward(self, x):
 
 
 class AutoShape(nn.Module):
-    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs.
-    # Includes preprocessing, inference and NMS
+    """
+    YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs.
+    Includes preprocessing, inference and NMS.
+    """
+
     conf = 0.25  # NMS confidence threshold
     iou = 0.45  # NMS IoU threshold
-    classes = None  # (optional list) filter by class
+    # (optional list) filter by class, e.g. [0, 15, 16] for COCO persons, cats and dogs
+    classes = None
     multi_label = False  # NMS multiple labels per box
     max_det = 1000  # maximum number of detections per image
 
     def __init__(self, model):
         super().__init__()
+        LOGGER.info("Adding AutoShape... ")
+        # copy attributes
+        copy_attr(self, model, include=("yaml", "nc", "hyp", "names", "stride", "abc"), exclude=())
         self.model = model.eval()
 
-    def autoshape(self):
-        LOGGER.info("AutoShape already enabled, skipping... ")  # model already converted to model.autoshape()
+    def _apply(self, fn):
+        """
+        Apply to(), cpu(), cuda(), half() to model tensors that
+        are not parameters or registered buffers
+        """
+        self = super()._apply(fn)
+        m = self.model.model[-1]  # Detect()
+        m.stride = fn(m.stride)
+        m.grid = list(map(fn, m.grid))
+        if isinstance(m.anchor_grid, list):
+            m.anchor_grid = list(map(fn, m.anchor_grid))
         return self
 
     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
-        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
-        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
-        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
-        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
-        #   numpy:           = np.zeros((640,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
-        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+        """
+        Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        - file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
+        - URI:             = 'https://ultralytics.com/images/zidane.jpg'
+        - OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        - PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
+        - numpy:           = np.zeros((640,1280,3))  # HWC
+        - torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
+        - multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+        """
+        from yolort.v5.utils.augmentations import letterbox
+        from yolort.v5.utils.datasets import exif_transpose
 
         t = [time_sync()]
         p = next(self.model.parameters())  # for device and type
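`p = next(self.model.parameters())` grabs a reference parameter so inputs can be matched to the model's device and dtype without inspecting either separately. A minimal, self-contained sketch of the idiom:

```python
import torch
from torch import nn

net = nn.Linear(3, 4).half()   # fp16 model, standing in for YOLOv5
p = next(net.parameters())     # reference parameter
x = torch.zeros(1, 3)          # fp32 input
x = x.to(p.device).type_as(p)  # now on the model's device, in fp16
print(x.dtype)                 # torch.float16
```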
@@ -448,10 +466,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
             return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
 
         # Pre-process
-        n, imgs = (
-            (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
-        )  # number of images, list of images
-        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+        # number of images, list of images
+        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
+        # image and inference shapes, filenames
+        shape0, shape1, files = [], [], []
         for i, im in enumerate(imgs):
             f = f"image{i}"  # filename
             if isinstance(im, (str, Path)):  # filename or uri
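The pre-process loop records each image's original shape (`shape0`) and the stride-aligned inference shape (`shape1`), then `letterbox` pads every image to that common shape. A standalone sketch of that step, assuming the `letterbox` import added above:

```python
import numpy as np
from yolort.v5.utils.augmentations import letterbox

im = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8 input
# pad to a fixed square shape without stride auto-adjustment, as in forward()
im_padded, ratio, (dw, dh) = letterbox(im, new_shape=(640, 640), auto=False)
print(im_padded.shape)                        # (640, 640, 3)
```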
@@ -476,7 +494,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
         x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
         x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
         x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
-        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0  # uint8 to fp16/32
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
         t.append(time_sync())
 
         with amp.autocast(enabled=p.device.type != "cpu"):
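`amp.autocast(enabled=p.device.type != "cpu")` turns mixed-precision inference on only when the model sits on a CUDA device, so the same code path stays correct on CPU. A minimal sketch of the guard:

```python
import torch
from torch.cuda import amp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(1, 16, device=device)
w = torch.randn(16, 16, device=device)
with amp.autocast(enabled=device.type != "cpu"):
    y = x @ w       # matmul runs in fp16 under autocast on CUDA, fp32 on CPU
print(y.dtype)      # torch.float16 on CUDA, torch.float32 on CPU
```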
@@ -492,7 +510,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
                 classes=self.classes,
                 multi_label=self.multi_label,
                 max_det=self.max_det,
-            )  # NMS
+            )
             for i in range(n):
                 scale_coords(shape1, y[i][:, :4], shape0[i])
 
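Taken together, the wrapper accepts paths, arrays, PIL images, or tensors; tensor inputs skip pre-processing entirely via the early `return self.model(...)` above. A hedged usage sketch, where `base_model` stands in for a loaded YOLOv5 model and is hypothetical here:

```python
import torch

model = AutoShape(base_model)  # base_model: hypothetical loaded YOLOv5 model
out = model("data/images/zidane.jpg", size=640)  # path input: letterboxed + NMS
out = model(torch.zeros(1, 3, 640, 640))         # BCHW tensor: raw forward, no NMS
```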