Skip to content

Commit f2254dd

Browse files
committed
1 parent 711a664 commit f2254dd

2 files changed

Lines changed: 16 additions & 18 deletions

File tree

yolort/v5/helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from .models import AutoShape
78
from .models.yolo import Model
89
from .utils import attempt_download, intersect_dicts, set_logging
910

@@ -68,6 +69,6 @@ def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bo
6869
model.load_state_dict(ckpt_state_dict, strict=False)
6970

7071
if autoshape:
71-
model = model.autoshape()
72+
model = AutoShape(model)
7273

7374
return model

yolort/v5/models/common.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -434,17 +434,17 @@ def autoshape(self):
434434

435435
@torch.no_grad()
436436
def forward(self, imgs, size=640, augment=False, profile=False):
437+
"""
438+
Inference from various sources. For height=640, width=1280, RGB images example inputs are:
439+
- file: imgs = 'data/images/zidane.jpg' # str or PosixPath
440+
- URI: = 'https://ultralytics.com/images/zidane.jpg'
441+
- OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
442+
- PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
443+
- numpy: = np.zeros((640,1280,3)) # HWC
444+
- torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
445+
- multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
446+
"""
437447
from yolort.v5.utils.augmentations import letterbox
438-
439-
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
440-
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
441-
# URI: = 'https://ultralytics.com/images/zidane.jpg'
442-
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
443-
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
444-
# numpy: = np.zeros((640,1280,3)) # HWC
445-
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
446-
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
447-
448448
from yolort.v5.utils.datasets import exif_transpose
449449

450450
t = [time_sync()]
@@ -454,17 +454,14 @@ def forward(self, imgs, size=640, augment=False, profile=False):
454454
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
455455

456456
# Pre-process
457-
n, imgs = (
458-
(len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
459-
) # number of images, list of images
457+
# number of images, list of images
458+
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
460459
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
461460
for i, im in enumerate(imgs):
462461
f = f"image{i}" # filename
463462
if isinstance(im, (str, Path)): # filename or uri
464-
im, f = (
465-
Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im),
466-
im,
467-
)
463+
f = im
464+
im = Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)
468465
im = np.asarray(exif_transpose(im))
469466
elif isinstance(im, Image.Image): # PIL Image
470467
im, f = np.asarray(exif_transpose(im)), getattr(im, "filename", f) or f

0 commit comments

Comments
 (0)