Skip to content

Commit 06ae39b

Browse files
JingyaHuanglewtun
authored andcommitted
Add onnx export cuda support (huggingface#17183)
Co-authored-by: Lysandre Debut <[email protected]> Co-authored-by: lewtun <[email protected]>
1 parent b109bc6 commit 06ae39b

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3099,7 +3099,7 @@ def forward(
30993099
# setting lengths logits to `-inf`
31003100
logits_mask = self.prepare_question_mask(question_lengths, seqlen)
31013101
if token_type_ids is None:
3102-
token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask
3102+
token_type_ids = torch.ones(logits_mask.size(), dtype=int, device=logits_mask.device) - logits_mask
31033103
logits_mask = logits_mask
31043104
logits_mask[:, 0] = False
31053105
logits_mask.unsqueeze_(2)

src/transformers/onnx/convert.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def export_pytorch(
8686
opset: int,
8787
output: Path,
8888
tokenizer: "PreTrainedTokenizer" = None,
89+
device: str = "cpu",
8990
) -> Tuple[List[str], List[str]]:
9091
"""
9192
Export a PyTorch model to an ONNX Intermediate Representation (IR)
@@ -101,6 +102,8 @@ def export_pytorch(
101102
The version of the ONNX operator set to use.
102103
output (`Path`):
103104
Directory to store the exported ONNX model.
105+
device (`str`, *optional*, defaults to `cpu`):
106+
The device on which the ONNX model will be exported. Either `cpu` or `cuda`.
104107
105108
Returns:
106109
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -137,6 +140,10 @@ def export_pytorch(
137140
# Ensure inputs match
138141
# TODO: Check when exporting QA we provide "is_pair=True"
139142
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
143+
device = torch.device(device)
144+
if device.type == "cuda" and torch.cuda.is_available():
145+
model.to(device)
146+
model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
140147
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
141148
onnx_outputs = list(config.outputs.keys())
142149

@@ -268,6 +275,7 @@ def export(
268275
opset: int,
269276
output: Path,
270277
tokenizer: "PreTrainedTokenizer" = None,
278+
device: str = "cpu",
271279
) -> Tuple[List[str], List[str]]:
272280
"""
273281
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
@@ -283,6 +291,9 @@ def export(
283291
The version of the ONNX operator set to use.
284292
output (`Path`):
285293
Directory to store the exported ONNX model.
294+
device (`str`, *optional*, defaults to `cpu`):
295+
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
296+
export on CUDA devices.
286297
287298
Returns:
288299
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -294,6 +305,9 @@ def export(
294305
"Please install torch or tensorflow first."
295306
)
296307

308+
if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda":
309+
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
310+
297311
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
298312
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
299313
if tokenizer is not None:
@@ -318,7 +332,7 @@ def export(
318332
)
319333

320334
if is_torch_available() and issubclass(type(model), PreTrainedModel):
321-
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
335+
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
322336
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
323337
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)
324338

@@ -359,6 +373,8 @@ def validate_model_outputs(
359373
session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])
360374

361375
# Compute outputs from the reference model
376+
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
377+
reference_model.to("cpu")
362378
ref_outputs = reference_model(**reference_model_inputs)
363379
ref_outputs_dict = {}
364380

tests/onnx/test_onnx_v2.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase):
242242
Integration tests ensuring supported models are correctly exported
243243
"""
244244

245-
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
245+
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
246246
from transformers.onnx import export
247247

248248
model_class = FeaturesManager.get_model_class_for_feature(feature)
@@ -273,7 +273,7 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
273273
with NamedTemporaryFile("w") as output:
274274
try:
275275
onnx_inputs, onnx_outputs = export(
276-
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
276+
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
277277
)
278278
validate_model_outputs(
279279
onnx_config,
@@ -294,6 +294,14 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
294294
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
295295
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
296296

297+
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
298+
@slow
299+
@require_torch
300+
@require_vision
301+
@require_rjieba
302+
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
303+
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
304+
297305
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
298306
@slow
299307
@require_torch

0 commit comments

Comments
 (0)