@@ -86,6 +86,7 @@ def export_pytorch(
8686 opset : int ,
8787 output : Path ,
8888 tokenizer : "PreTrainedTokenizer" = None ,
89+ device : str = "cpu" ,
8990) -> Tuple [List [str ], List [str ]]:
9091 """
9192 Export a PyTorch model to an ONNX Intermediate Representation (IR)
@@ -101,6 +102,8 @@ def export_pytorch(
101102 The version of the ONNX operator set to use.
102103 output (`Path`):
103104 Directory to store the exported ONNX model.
105+ device (`str`, *optional*, defaults to `cpu`):
106+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`.
104107
105108 Returns:
106109 `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -137,6 +140,10 @@ def export_pytorch(
137140 # Ensure inputs match
138141 # TODO: Check when exporting QA we provide "is_pair=True"
139142 model_inputs = config .generate_dummy_inputs (preprocessor , framework = TensorType .PYTORCH )
143+ device = torch .device (device )
144+ if device .type == "cuda" and torch .cuda .is_available ():
145+ model .to (device )
146+ model_inputs = dict ((k , v .to (device )) for k , v in model_inputs .items ())
140147 inputs_match , matched_inputs = ensure_model_and_config_inputs_match (model , model_inputs .keys ())
141148 onnx_outputs = list (config .outputs .keys ())
142149
@@ -268,6 +275,7 @@ def export(
268275 opset : int ,
269276 output : Path ,
270277 tokenizer : "PreTrainedTokenizer" = None ,
278+ device : str = "cpu" ,
271279) -> Tuple [List [str ], List [str ]]:
272280 """
273281 Export a PyTorch or TensorFlow model to an ONNX Intermediate Representation (IR)
@@ -283,6 +291,9 @@ def export(
283291 The version of the ONNX operator set to use.
284292 output (`Path`):
285293 Directory to store the exported ONNX model.
294+ device (`str`, *optional*, defaults to `cpu`):
295+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
296+ export on CUDA devices.
286297
287298 Returns:
288299 `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -294,6 +305,9 @@ def export(
294305 "Please install torch or tensorflow first."
295306 )
296307
308+ if is_tf_available () and isinstance (model , TFPreTrainedModel ) and device == "cuda" :
309+ raise RuntimeError ("`tf2onnx` does not support export on CUDA device." )
310+
297311 if isinstance (preprocessor , PreTrainedTokenizerBase ) and tokenizer is not None :
298312 raise ValueError ("You cannot provide both a tokenizer and a preprocessor to export the model." )
299313 if tokenizer is not None :
@@ -318,7 +332,7 @@ def export(
318332 )
319333
320334 if is_torch_available () and issubclass (type (model ), PreTrainedModel ):
321- return export_pytorch (preprocessor , model , config , opset , output , tokenizer = tokenizer )
335+ return export_pytorch (preprocessor , model , config , opset , output , tokenizer = tokenizer , device = device )
322336 elif is_tf_available () and issubclass (type (model ), TFPreTrainedModel ):
323337 return export_tensorflow (preprocessor , model , config , opset , output , tokenizer = tokenizer )
324338
@@ -359,6 +373,8 @@ def validate_model_outputs(
359373 session = InferenceSession (onnx_model .as_posix (), options , providers = ["CPUExecutionProvider" ])
360374
361375 # Compute outputs from the reference model
376+ if is_torch_available () and issubclass (type (reference_model ), PreTrainedModel ):
377+ reference_model .to ("cpu" )
362378 ref_outputs = reference_model (** reference_model_inputs )
363379 ref_outputs_dict = {}
364380
0 commit comments