Exported DeBERTa ONNX model is incorrect #18199

@JingyaHuang

System Info

  • transformers version: 4.21.0.dev0
  • Platform: Linux-5.4.0-121-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.11.0+cu113 (True)

Who can help?

@LysandreJik

Reproduction

from pathlib import Path
from transformers.onnx import export
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.models.deberta_v2 import DebertaV2OnnxConfig

# load model and tokenizer
onnx_path = Path("results/deberta-v2-model.onnx")
model_ckpt = "microsoft/deberta-v2-xxlarge"
base_model = AutoModel.from_pretrained(model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
onnx_config = DebertaV2OnnxConfig(base_model.config)

# export to onnx
onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

Trace Warnings

/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:564: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  q_ids = np.arange(0, query_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:564: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  q_ids = np.arange(0, query_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:565: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k_ids = np.arange(0, key_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:565: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k_ids = np.arange(0, key_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:569: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:698: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale = math.sqrt(query_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:752: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:754: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  query_layer.size(0) // self.num_attention_heads, 1, 1
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:773: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:785: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:786: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if key_layer.size(-2) != query_layer.size(-2):
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:113: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))

Some operations in the model are not native torch ops (they use numpy or math), so the tracer cannot record them and the exported graph is incorrect.
For example, the following graph corresponds to

query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])

scale = math.sqrt(query_layer.size(-1) * scale_factor)
query_layer = query_layer / scale

[image: exported ONNX graph for the snippet above]

As shown in the graph, the sqrt node has been ignored and the value of scale is treated as a constant.
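
The effect can be reproduced outside of DeBERTa with a small sketch. The function names, the scale factor of 3, and the input shapes below are made up for illustration; only `torch.jit.trace` and the `math.sqrt` pattern come from the report. It relies on the fact that, while tracing, `x.size(-1)` is a 0-dim tensor, so plain tensor arithmetic on it stays in the graph, whereas `math.sqrt` converts it to a Python float.

```python
import math
import torch

SCALE_FACTOR = 3  # hypothetical value, chosen only for this sketch


def scale_with_math(x):
    # math.sqrt converts the traced size to a Python float,
    # so the tracer records the result as a constant.
    scale = math.sqrt(x.size(-1) * SCALE_FACTOR)
    return x / scale


def scale_with_tensor_ops(x):
    # While tracing, x.size(-1) is a 0-dim tensor; multiplying it and
    # raising it to 0.5 are tensor ops, so they stay in the graph.
    scale = (x.size(-1) * SCALE_FACTOR) ** 0.5
    return x / scale


example = torch.ones(2, 4)   # shape used for tracing
other = torch.ones(2, 16)    # same rank, different last dimension

traced_math = torch.jit.trace(scale_with_math, example)
traced_tensor = torch.jit.trace(scale_with_tensor_ops, example)

# The math version reuses the scale that was computed for size 4.
print(torch.allclose(traced_math(other), scale_with_math(other)))
# The tensor version recomputes the scale from the new size.
print(torch.allclose(traced_tensor(other), scale_with_tensor_ops(other)))
```

Tracing `scale_with_math` emits the same kind of TracerWarning as the ones listed above, and its traced output only matches eager execution for inputs whose last dimension equals the one used during tracing.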

Expected behavior

Correctly export the ONNX model without triggering any TracerWarning.

This requires replacing the numpy and math ops with natively supported torch ops. I can open a PR for the replacement.
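
As a rough idea of what such a replacement could look like for the `np.arange` calls flagged in the warnings, here is a minimal sketch. The helper name `build_relative_position_ids` is hypothetical, and any bucketing of the positions is left out; this is not the actual patch, only the numpy-to-torch swap in isolation.

```python
import torch


def build_relative_position_ids(query_size, key_size):
    # Hypothetical helper: torch.arange replaces np.arange so the
    # index ranges are created by torch ops rather than numpy.
    q_ids = torch.arange(0, query_size)
    k_ids = torch.arange(0, key_size)
    # Broadcast to a (query_size, key_size) grid of i - j offsets.
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    # Add a leading batch dimension and make the dtype explicit.
    return rel_pos_ids.unsqueeze(0).to(torch.long)


print(build_relative_position_ids(2, 3))
```

Because the result is already a torch tensor, the later `torch.tensor(rel_pos_ids, dtype=torch.long)` conversion seen in the warnings would no longer be needed in this sketch.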
