Skip to content
30 changes: 17 additions & 13 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
# limitations under the License.
""" PyTorch DeBERTa-v2 model."""

import math
from collections.abc import Sequence
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
Expand Down Expand Up @@ -535,11 +533,17 @@ def custom_forward(*inputs):


def make_log_bucket_position(relative_pos, bucket_size, max_position):
    """
    Map relative positions to log-spaced bucket indices.

    Positions whose magnitude is at most `bucket_size // 2` are returned unchanged; larger magnitudes are
    compressed logarithmically and keep the sign of the original position.

    Args:
        relative_pos (`torch.Tensor`): Integer relative positions (e.g. the query/key index differences computed
            in `build_relative_position`).
        bucket_size (`int`): Total number of buckets; half of it (`mid`) is the linear/logarithmic threshold.
        max_position (`int`): Maximum position, used to normalise the logarithmic part.

    Return:
        `torch.Tensor`: Bucketed positions with the same shape as `relative_pos`. The result is floating point;
        callers are expected to cast it to an integer dtype.
    """
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    # Inside the linear window the magnitude is replaced by `mid - 1` so that the logarithm below never sees a
    # zero; those entries are overwritten with the raw relative position in the final `torch.where` anyway.
    abs_pos = torch.where(
        (relative_pos < mid) & (relative_pos > -mid),
        torch.tensor(mid - 1).type_as(relative_pos),
        torch.abs(relative_pos),
    )
    log_pos = (
        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
    )
    # `torch.where` needs both branches to share a dtype, hence the cast of `relative_pos` to `log_pos`'s type.
    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
    return bucket_pos


Expand All @@ -561,12 +565,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]

"""
q_ids = np.arange(0, query_size)
k_ids = np.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
if bucket_size > 0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
rel_pos_ids = rel_pos_ids.to(torch.long)
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids
Expand Down Expand Up @@ -695,7 +699,7 @@ def forward(
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = math.sqrt(query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
Expand Down Expand Up @@ -770,7 +774,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
score = 0
# content->position
if "c2p" in self.pos_att_type:
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_att = torch.gather(
Expand All @@ -782,7 +786,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_

# position->content
if "p2c" in self.pos_att_type:
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
if key_layer.size(-2) != query_layer.size(-2):
r_pos = build_relative_position(
key_layer.size(-2),
Expand Down
28 changes: 17 additions & 11 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,17 @@ def compute_num_masked_span(input_length):

# Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
def make_log_bucket_position(relative_pos, bucket_size, max_position):
    """
    Map relative positions to log-spaced bucket indices.

    Positions whose magnitude is at most `bucket_size // 2` are returned unchanged; larger magnitudes are
    compressed logarithmically and keep the sign of the original position.

    Args:
        relative_pos (`torch.Tensor`): Integer relative positions (e.g. the query/key index differences computed
            in `build_relative_position`).
        bucket_size (`int`): Total number of buckets; half of it (`mid`) is the linear/logarithmic threshold.
        max_position (`int`): Maximum position, used to normalise the logarithmic part.

    Return:
        `torch.Tensor`: Bucketed positions with the same shape as `relative_pos`. The result is floating point;
        callers are expected to cast it to an integer dtype.
    """
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    # Inside the linear window the magnitude is replaced by `mid - 1` so that the logarithm below never sees a
    # zero; those entries are overwritten with the raw relative position in the final `torch.where` anyway.
    abs_pos = torch.where(
        (relative_pos < mid) & (relative_pos > -mid),
        torch.tensor(mid - 1).type_as(relative_pos),
        torch.abs(relative_pos),
    )
    log_pos = (
        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
    )
    # `torch.where` needs both branches to share a dtype, hence the cast of `relative_pos` to `log_pos`'s type.
    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
    return bucket_pos


Expand All @@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]

"""
q_ids = np.arange(0, query_size)
k_ids = np.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
if bucket_size > 0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
rel_pos_ids = rel_pos_ids.to(torch.long)
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids
Expand Down Expand Up @@ -767,7 +773,7 @@ def forward(
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = math.sqrt(query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
Expand Down Expand Up @@ -842,7 +848,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
score = 0
# content->position
if "c2p" in self.pos_att_type:
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_att = torch.gather(
Expand All @@ -854,7 +860,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_

# position->content
if "p2c" in self.pos_att_type:
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
if key_layer.size(-2) != query_layer.size(-2):
r_pos = build_relative_position(
key_layer.size(-2),
Expand Down