Skip to content

Commit 7a39803

Browse files
imyhxy and glenn-jocher
authored
Export, detect and validation with TensorRT engine file (#5699)
* Export and detect with TensorRT engine file * Resolve `isort` * Make validation works with TensorRT engine * feat: update export docstring * feat: change suffix from *.trt to *.engine * feat: get rid of pycuda * feat: make compatiable with val.py * feat: support detect with fp16 engine * Add Lite to Edge TPU string * Remove *.trt comment * Revert to standard success logger.info string * Fix Deprecation Warning ``` export.py:310: DeprecationWarning: Use build_serialized_network instead. with builder.build_engine(network, config) as engine, open(f, 'wb') as t: ``` * Revert deprecation warning fix @imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here. * Update export.py * Update export.py * Update common.py * export onnx to file before building TensorRT engine file * feat: triger ONNX export failed early * feat: load ONNX model from file Co-authored-by: Glenn Jocher <[email protected]>
1 parent f17c86b commit 7a39803

File tree

4 files changed

+90
-11
lines changed

4 files changed

+90
-11
lines changed

detect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
7777
# Load model
7878
device = select_device(device)
7979
model = DetectMultiBackend(weights, device=device, dnn=dnn)
80-
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
80+
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
8181
imgsz = check_img_size(imgsz, s=stride) # check image size
8282

8383
# Half
84-
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
84+
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
8585
if pt:
8686
model.model.half() if half else model.model.float()
8787

export.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TensorFlow GraphDef | yolov5s.pb | 'pb'
1313
TensorFlow Lite | yolov5s.tflite | 'tflite'
1414
TensorFlow.js | yolov5s_web_model/ | 'tfjs'
15+
TensorRT | yolov5s.engine | 'engine'
1516
1617
Usage:
1718
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
@@ -24,6 +25,7 @@
2425
yolov5s_saved_model
2526
yolov5s.pb
2627
yolov5s.tflite
28+
yolov5s.engine
2729
2830
TensorFlow.js:
2931
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@@ -263,6 +265,51 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
263265
LOGGER.info(f'\n{prefix} export failure: {e}')
264266

265267

268+
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
269+
try:
270+
check_requirements(('tensorrt',))
271+
import tensorrt as trt
272+
273+
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
274+
export_onnx(model, im, file, opset, train, False, simplify)
275+
onnx = file.with_suffix('.onnx')
276+
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
277+
278+
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279+
f = str(file).replace('.pt', '.engine') # TensorRT engine file
280+
logger = trt.Logger(trt.Logger.INFO)
281+
if verbose:
282+
logger.min_severity = trt.Logger.Severity.VERBOSE
283+
284+
builder = trt.Builder(logger)
285+
config = builder.create_builder_config()
286+
config.max_workspace_size = workspace * 1 << 30
287+
288+
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
289+
network = builder.create_network(flag)
290+
parser = trt.OnnxParser(network, logger)
291+
if not parser.parse_from_file(str(onnx)):
292+
raise RuntimeError(f'failed to load ONNX file: {onnx}')
293+
294+
inputs = [network.get_input(i) for i in range(network.num_inputs)]
295+
outputs = [network.get_output(i) for i in range(network.num_outputs)]
296+
LOGGER.info(f'{prefix} Network Description:')
297+
for inp in inputs:
298+
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
299+
for out in outputs:
300+
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
301+
302+
half &= builder.platform_has_fast_fp16
303+
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
304+
if half:
305+
config.set_flag(trt.BuilderFlag.FP16)
306+
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
307+
t.write(engine.serialize())
308+
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
309+
310+
except Exception as e:
311+
LOGGER.info(f'\n{prefix} export failure: {e}')
312+
266313
@torch.no_grad()
267314
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
268315
weights=ROOT / 'yolov5s.pt', # weights path
@@ -278,6 +325,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
278325
dynamic=False, # ONNX/TF: dynamic axes
279326
simplify=False, # ONNX: simplify model
280327
opset=12, # ONNX: opset version
328+
verbose=False, # TensorRT: verbose log
329+
workspace=4, # TensorRT: workspace size (GB)
281330
topk_per_class=100, # TF.js NMS: topk per class to keep
282331
topk_all=100, # TF.js NMS: topk for all classes to keep
283332
iou_thres=0.45, # TF.js NMS: IoU threshold
@@ -322,6 +371,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
322371
export_torchscript(model, im, file, optimize)
323372
if 'onnx' in include:
324373
export_onnx(model, im, file, opset, train, dynamic, simplify)
374+
if 'engine' in include:
375+
export_engine(model, im, file, train, half, simplify, workspace, verbose)
325376
if 'coreml' in include:
326377
export_coreml(model, im, file)
327378

@@ -360,13 +411,15 @@ def parse_opt():
360411
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
361412
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
362413
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
414+
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
415+
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
363416
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
364417
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
365418
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
366419
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
367420
parser.add_argument('--include', nargs='+',
368421
default=['torchscript', 'onnx'],
369-
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
422+
help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
370423
opt = parser.parse_args()
371424
print_args(FILE.stem, opt)
372425
return opt

models/common.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import platform
99
import warnings
10+
from collections import namedtuple
1011
from copy import copy
1112
from pathlib import Path
1213

@@ -285,11 +286,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
285286
# TensorFlow Lite: *.tflite
286287
# ONNX Runtime: *.onnx
287288
# OpenCV DNN: *.onnx with dnn=True
289+
# TensorRT: *.engine
288290
super().__init__()
289291
w = str(weights[0] if isinstance(weights, list) else weights)
290-
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
292+
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
291293
check_suffix(w, suffixes) # check weights have acceptable suffix
292-
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
294+
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
293295
jit = pt and 'torchscript' in w.lower()
294296
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
295297

@@ -317,6 +319,23 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
317319
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
318320
import onnxruntime
319321
session = onnxruntime.InferenceSession(w, None)
322+
elif engine: # TensorRT
323+
LOGGER.info(f'Loading {w} for TensorRT inference...')
324+
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
325+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
326+
logger = trt.Logger(trt.Logger.INFO)
327+
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
328+
model = runtime.deserialize_cuda_engine(f.read())
329+
bindings = dict()
330+
for index in range(model.num_bindings):
331+
name = model.get_binding_name(index)
332+
dtype = trt.nptype(model.get_binding_dtype(index))
333+
shape = tuple(model.get_binding_shape(index))
334+
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
335+
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
336+
binding_addrs = {n: d.ptr for n, d in bindings.items()}
337+
context = model.create_execution_context()
338+
batch_size = bindings['images'].shape[0]
320339
else: # TensorFlow model (TFLite, pb, saved_model)
321340
import tensorflow as tf
322341
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
@@ -334,7 +353,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
334353
model = tf.keras.models.load_model(w)
335354
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
336355
if 'edgetpu' in w.lower():
337-
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
356+
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
338357
import tflite_runtime.interpreter as tfli
339358
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
340359
'Darwin': 'libedgetpu.1.dylib',
@@ -369,6 +388,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
369388
y = self.net.forward()
370389
else: # ONNX Runtime
371390
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
391+
elif self.engine: # TensorRT
392+
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
393+
self.binding_addrs['images'] = int(im.data_ptr())
394+
self.context.execute_v2(list(self.binding_addrs.values()))
395+
y = self.bindings['output'].data
372396
else: # TensorFlow model (TFLite, pb, saved_model)
373397
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
374398
if self.pb:
@@ -391,7 +415,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
391415
y[..., 1] *= h # y
392416
y[..., 2] *= w # w
393417
y[..., 3] *= h # h
394-
y = torch.tensor(y)
418+
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
395419
return (y, []) if val else y
396420

397421

val.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def run(data,
111111
# Initialize/load model and set device
112112
training = model is not None
113113
if training: # called by train.py
114-
device, pt = next(model.parameters()).device, True # get model device, PyTorch model
114+
device, pt, engine = next(model.parameters()).device, True, False # get model device, PyTorch model
115115

116116
half &= device.type != 'cpu' # half precision only supported on CUDA
117117
model.half() if half else model.float()
@@ -124,11 +124,13 @@ def run(data,
124124

125125
# Load model
126126
model = DetectMultiBackend(weights, device=device, dnn=dnn)
127-
stride, pt = model.stride, model.pt
127+
stride, pt, engine = model.stride, model.pt, model.engine
128128
imgsz = check_img_size(imgsz, s=stride) # check image size
129-
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
129+
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
130130
if pt:
131131
model.model.half() if half else model.model.float()
132+
elif engine:
133+
batch_size = model.batch_size
132134
else:
133135
half = False
134136
batch_size = 1 # export.py models default to batch-size 1
@@ -165,7 +167,7 @@ def run(data,
165167
pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
166168
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
167169
t1 = time_sync()
168-
if pt:
170+
if pt or engine:
169171
im = im.to(device, non_blocking=True)
170172
targets = targets.to(device)
171173
im = im.half() if half else im.float() # uint8 to fp16/32

0 commit comments

Comments (0)