Skip to content

Commit 407a905

Browse files
Check TensorRT>=8.0.0 version (#6021)
* Check TensorRT>=8.0.0 version * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c1249a4 commit 407a905

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

models/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from torch.cuda import amp
2222

2323
from utils.datasets import exif_transpose, letterbox
24-
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
25-
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
24+
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
25+
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
2626
from utils.plots import Annotator, colors, save_one_box
2727
from utils.torch_utils import copy_attr, time_sync
2828

@@ -328,6 +328,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
328328
elif engine: # TensorRT
329329
LOGGER.info(f'Loading {w} for TensorRT inference...')
330330
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
331+
check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
331332
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
332333
logger = trt.Logger(trt.Logger.INFO)
333334
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:

utils/general.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,16 @@ def check_python(minimum='3.6.2'):
248248
check_version(platform.python_version(), minimum, name='Python ', hard=True)
249249

250250

251-
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
251+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
252252
# Check version vs. required version
253253
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
254254
result = (current == minimum) if pinned else (current >= minimum) # bool
255-
if hard: # assert min requirements met
256-
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
257-
else:
258-
return result
255+
s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
256+
if hard:
257+
assert result, s # assert min requirements met
258+
if verbose and not result:
259+
LOGGER.warning(s)
260+
return result
259261

260262

261263
@try_except

0 commit comments

Comments
 (0)