System Info
- transformers version: 4.21.0.dev0
- Platform: Linux-5.4.0-121-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.8.1
- PyTorch version (GPU?): 1.11.0+cu113 (True)
Who can help?
Reproduction
from pathlib import Path
from transformers.onnx import export
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.models.deberta_v2 import DebertaV2OnnxConfig
# load model and tokenizer
onnx_path = Path("results/deberta-v2-model.onnx")
model_ckpt = "microsoft/deberta-v2-xxlarge"
base_model = AutoModel.from_pretrained(model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
onnx_config = DebertaV2OnnxConfig(base_model.config)
# export to onnx
onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
Trace Warnings
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:564: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
q_ids = np.arange(0, query_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:564: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
q_ids = np.arange(0, query_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:565: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
k_ids = np.arange(0, key_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:565: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
k_ids = np.arange(0, key_size)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:569: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:698: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
scale = math.sqrt(query_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:752: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:754: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
query_layer.size(0) // self.num_attention_heads, 1, 1
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:773: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:785: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:786: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if key_layer.size(-2) != query_layer.size(-2):
/usr/local/lib/python3.8/dist-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:113: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
There are some operations which are not torch-native (numpy, math), and these break the tracing: their results are recorded as constants instead of staying in the graph.
e.g. the following graph corresponds to

query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])

transformers/src/transformers/models/deberta/modeling_deberta.py, lines 633 to 634 in d0acc95:

scale = math.sqrt(query_layer.size(-1) * scale_factor)
query_layer = query_layer / scale
As shown in the graph, the sqrt node has been ignored and the value of scale is treated as a constant.
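To make the effect concrete, here is a minimal standalone sketch (not the DeBERTa code itself; the helper names are made up) contrasting a numpy-based and a torch-based version of the same computation under torch.jit.trace. The numpy result is frozen into the trace as a constant, while the torch op can follow the input shape.

import numpy as np
import torch

def ids_numpy(x):
    # np.arange runs in Python at trace time, so its result is recorded as a constant
    return torch.as_tensor(np.arange(0, x.size(1)))

def ids_torch(x):
    # torch.arange is a traced op, so it can follow x.size(1) at run time
    return torch.arange(0, x.size(1))

short, longer = torch.randn(1, 4), torch.randn(1, 9)
traced_np = torch.jit.trace(ids_numpy, short)   # emits TracerWarnings like the ones above
traced_pt = torch.jit.trace(ids_torch, short)

print(traced_np(longer).shape)  # torch.Size([4]): frozen at trace time
print(traced_pt(longer).shape)  # should be torch.Size([9]): follows the new length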
Expected behavior
The ONNX model should be exported correctly, without triggering TracerWarnings.
For that, the numpy and math ops need to be replaced with natively supported torch ops. I can open a PR for the replacement; a rough sketch of the direction is below.
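For reference, a rough sketch of the kind of replacement meant here (standalone code; the names mirror the DeBERTa functions above, but the actual PR may look different):

import torch

def build_relative_position(query_size, key_size, device=None):
    # was: q_ids = np.arange(0, query_size); k_ids = np.arange(0, key_size)
    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
    # was: numpy arithmetic followed by torch.tensor(rel_pos_ids, dtype=torch.long)
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    return rel_pos_ids.unsqueeze(0)

def scale_query(query_layer, scale_factor):
    # was: scale = math.sqrt(query_layer.size(-1) * scale_factor)
    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    return query_layer / scale.to(query_layer.dtype)

The __floordiv__ deprecation warnings can be addressed separately, e.g. with torch.div(a, b, rounding_mode="floor") as the warning message itself suggests.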