@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
@@ -640,8 +639,8 @@ def linear(w, b, x):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
             query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -650,8 +649,8 @@ def linear(w, b, x):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -711,13 +710,13 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
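
The snippet below is a minimal, illustrative sketch (not part of the commit) of the scaling pattern the "+" lines apply: computing the attention scale with torch.sqrt over a tensor instead of math.sqrt over a Python float, so the computation stays in tensor operations. The shapes and the scale_factor value are assumed for the example.

    import torch

    # Assumed shape: (batch, heads, seq_len, head_dim)
    query_layer = torch.randn(2, 12, 8, 64)
    scale_factor = 3  # 1 + len(pos_att_type), assuming two relative-attention types

    # Tensor-based scale, mirroring the added lines in the diff above
    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)

Note that wrapping an existing tensor in torch.tensor copies it (recent PyTorch versions warn and suggest clone().detach() or .to() instead), but it keeps the dtype conversion explicit, which is what the diff relies on.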