@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
433433 conf_thres = 0.25 # TF.js NMS: confidence threshold
434434 ):
435435 t = time .time ()
436- include = [x .lower () for x in include ]
437- tf_exports = list (x in include for x in ('saved_model' , 'pb' , 'tflite' , 'edgetpu' , 'tfjs' )) # TensorFlow exports
438- file = Path (url2file (weights ) if str (weights ).startswith (('http:/' , 'https:/' )) else weights )
436+ include = [x .lower () for x in include ] # to lowercase
437+ formats = tuple (export_formats ()['Argument' ][1 :]) # --include arguments
438+ flags = [x in include for x in formats ]
439+ assert sum (flags ) == len (include ), f'ERROR: Invalid --include { include } , valid --include arguments are { formats } '
440+ jit , onnx , xml , engine , coreml , saved_model , pb , tflite , edgetpu , tfjs = flags # export booleans
441+ file = Path (url2file (weights ) if str (weights ).startswith (('http:/' , 'https:/' )) else weights ) # PyTorch weights
439442
440443 # Load PyTorch model
441444 device = select_device (device )
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
475478 # Exports
476479 f = ['' ] * 10 # exported filenames
477480 warnings .filterwarnings (action = 'ignore' , category = torch .jit .TracerWarning ) # suppress TracerWarning
478- if 'torchscript' in include :
481+ if jit :
479482 f [0 ] = export_torchscript (model , im , file , optimize )
480- if ' engine' in include : # TensorRT required before ONNX
483+ if engine : # TensorRT required before ONNX
481484 f [1 ] = export_engine (model , im , file , train , half , simplify , workspace , verbose )
482- if ( ' onnx' in include ) or ( 'openvino' in include ) : # OpenVINO requires ONNX
485+ if onnx or xml : # OpenVINO requires ONNX
483486 f [2 ] = export_onnx (model , im , file , opset , train , dynamic , simplify )
484- if 'openvino' in include :
487+ if xml : # OpenVINO
485488 f [3 ] = export_openvino (model , im , file )
486- if ' coreml' in include :
489+ if coreml :
487490 _ , f [4 ] = export_coreml (model , im , file )
488491
489492 # TensorFlow Exports
490- if any (tf_exports ):
491- pb , tflite , edgetpu , tfjs = tf_exports [1 :]
493+ if any ((saved_model , pb , tflite , edgetpu , tfjs )):
492494 if int8 or edgetpu : # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
493495 check_requirements (('flatbuffers==1.12' ,)) # required before `import tensorflow`
494496 assert not (tflite and tfjs ), 'TFLite and TF.js models must be exported separately, please pass only one type.'
0 commit comments