
Commit 86d0b26

Fix matmul inputs dtype (#18585)
1 parent c99e984 commit 86d0b26

File tree

3 files changed, +16 -13 lines

src/transformers/models/deberta/modeling_deberta.py
src/transformers/models/deberta_v2/modeling_deberta_v2.py
src/transformers/models/sew_d/modeling_sew_d.py

src/transformers/models/deberta/modeling_deberta.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
@@ -640,8 +639,8 @@ def linear(w, b, x):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -650,8 +649,8 @@ def linear(w, b, x):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -711,13 +710,13 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
```
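The new casts above address torch.matmul's requirement that both operands share a dtype: when the model runs in reduced precision but query_states or the relative-position projections arrive in float32, the product raises a RuntimeError. Below is a minimal sketch of that failure and of the alignment the patch applies; it is not part of the commit (shapes are arbitrary, and float64 vs. float32 stands in for the float32-vs-float16 case so the demo runs on any CPU build).

```python
import torch

# Weight and input with mismatched floating dtypes; stand-ins for qkvw[0] and
# query_states. float64 vs. float32 is used only so this runs on any CPU build.
w = torch.randn(8, 8, dtype=torch.float32)
x = torch.randn(2, 8, dtype=torch.float64)

try:
    torch.matmul(x, w.t())  # mismatched dtypes -> RuntimeError
except RuntimeError as err:
    print(f"matmul failed: {err}")

# The patched linear() call rebuilds the input in the weight's dtype
# (spelled torch.tensor(query_states, dtype=qkvw[0].dtype) in the diff);
# .to() gives the same alignment here.
out = torch.matmul(x.to(w.dtype), w.t())
print(out.dtype)  # torch.float32
```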

src/transformers/models/deberta_v2/modeling_deberta_v2.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -717,7 +717,9 @@ def forward(
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -799,7 +801,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
```
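In this file the scale is built in float32 via torch.sqrt(torch.tensor(...)), so each division now casts it to the scores' dtype instead of leaning on implicit type promotion. A small sketch of the pattern, with a hypothetical head size and scale_factor not taken from the commit:

```python
import torch

head_dim = 64       # hypothetical head size
scale_factor = 3    # 1 + "c2p" + "p2c", as in the patched forward()

# scale comes out as a 0-dim float32 tensor.
scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)
print(scale.dtype)  # torch.float32

# Stand-in for torch.bmm(query_layer, key_layer.transpose(-1, -2)) in half precision.
attention_scores = torch.randn(2, 8, 8, dtype=torch.float16)

# Divide by the scale recast to the scores' dtype
# (spelled torch.tensor(scale, dtype=query_layer.dtype) in the diff).
attention_scores = attention_scores / scale.to(attention_scores.dtype)
print(attention_scores.dtype)  # torch.float16
```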

src/transformers/models/sew_d/modeling_sew_d.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -791,7 +791,9 @@ def forward(
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
```
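If the motivation is half-precision inference, as these casts suggest, a rough smoke test might look like the sketch below. It is not part of the commit: the checkpoint name is illustrative, and it assumes network access, a CUDA device, and sentencepiece for the tokenizer.

```python
import torch
from transformers import AutoTokenizer, DebertaV2Model

# Hypothetical checkpoint; any DeBERTa-v2/v3 checkpoint exercises the patched attention.
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base").half().to("cuda").eval()

inputs = tokenizer("dtype smoke test", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

# With the q/k/v inputs and the scale cast to a common dtype, the forward pass stays in fp16.
print(outputs.last_hidden_state.dtype)  # torch.float16
```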
