From fa29764fe8a98551eddab5deecbfd2813fda0454 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Sat, 23 Jul 2022 22:02:19 +0200 Subject: [PATCH 01/13] Fix critical trace warnings to allow ONNX export --- .../models/deberta_v2/modeling_deberta_v2.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index dd820590b66c..3d5739cbf0af 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch DeBERTa-v2 model.""" -import math from collections.abc import Sequence from typing import Optional, Tuple, Union @@ -535,11 +534,11 @@ def custom_forward(*inputs): def make_log_bucket_position(relative_pos, bucket_size, max_position): - sign = np.sign(relative_pos) + sign = torch.sign(relative_pos) mid = bucket_size // 2 - abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) - log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid - bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), torch.tensor(mid - 1).type_as(relative_pos), torch.abs(relative_pos)) + log_pos = torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid + bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign) return bucket_pos @@ -561,12 +560,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- `torch.LongTensor`: A tensor with shape [1, query_size, key_size] """ - q_ids = np.arange(0, query_size) - k_ids = np.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + q_ids = torch.arange(0, query_size) + k_ids = torch.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - k_ids.repeat(q_ids.shape[0], 1) if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids.type(torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -695,7 +694,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = math.sqrt(query_layer.size(-1) * scale_factor) + scale = torch.sqrt(query_layer.size(-1).type(torch.float) * scale_factor) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -770,7 +769,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ score = 0 # content->position if "c2p" in self.pos_att_type: - scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + scale = torch.sqrt(pos_key_layer.size(-1).type(torch.float) * scale_factor) c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) c2p_att = torch.gather( @@ -782,7 +781,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # position->content if "p2c" in self.pos_att_type: - scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + scale = torch.sqrt(pos_query_layer.size(-1).type(torch.float) * scale_factor) if key_layer.size(-2) != query_layer.size(-2): r_pos = build_relative_position( key_layer.size(-2), From c13ee0d5e9688387813d4e74f745fca3f0f1c32c Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Sat, 23 Jul 2022 22:31:38 +0200 Subject: [PATCH 02/13] Force input to `sqrt` to be float type --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 3d5739cbf0af..98ea21c18e6d 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -694,7 +694,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(query_layer.size(-1).type(torch.float) * scale_factor) + scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -769,7 +769,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ score = 0 # content->position if "c2p" in self.pos_att_type: - scale = torch.sqrt(pos_key_layer.size(-1).type(torch.float) * scale_factor) + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float)) c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) c2p_att = torch.gather( @@ -781,7 +781,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # position->content if "p2c" in self.pos_att_type: - scale = torch.sqrt(pos_query_layer.size(-1).type(torch.float) * scale_factor) + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float)) if key_layer.size(-2) != query_layer.size(-2): r_pos = build_relative_position( key_layer.size(-2), From e09ee154aea36f17def4553bff771df789cb4a73 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Sat, 23 Jul 2022 23:07:50 +0200 Subject: [PATCH 03/13] Cleanup code --- .../models/deberta_v2/modeling_deberta_v2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 98ea21c18e6d..aa01767163cf 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -536,8 +536,14 @@ def custom_forward(*inputs): def make_log_bucket_position(relative_pos, bucket_size, max_position): sign = torch.sign(relative_pos) mid = bucket_size // 2 - abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), torch.tensor(mid - 1).type_as(relative_pos), torch.abs(relative_pos)) - log_pos = torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid + abs_pos = torch.where( + (relative_pos < mid) & (relative_pos > -mid), + torch.tensor(mid - 1).type_as(relative_pos), + torch.abs(relative_pos), + ) + log_pos = ( + torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid + ) bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign) return bucket_pos From 4800723107052860c3430a8b9cf9506403af4d39 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Sat, 23 Jul 2022 23:13:04 +0200 Subject: [PATCH 04/13] Remove unused import statement --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index aa01767163cf..0048a8355dcf 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -17,7 +17,6 @@ from collections.abc import Sequence from typing import Optional, Tuple, Union -import numpy as np import torch import torch.utils.checkpoint from torch import nn From 77987c190c7572138b998f4b8cea89fcf868b712 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Sat, 23 Jul 2022 23:23:05 +0200 Subject: [PATCH 05/13] Update model sew --- .../models/sew_d/modeling_sew_d.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 8dc210d06ca6..38afeff89e08 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -194,11 +194,17 @@ def compute_num_masked_span(input_length): # Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position def make_log_bucket_position(relative_pos, bucket_size, max_position): - sign = np.sign(relative_pos) + sign = torch.sign(relative_pos) mid = bucket_size // 2 - abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) - log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid - bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + abs_pos = torch.where( + (relative_pos < mid) & (relative_pos > -mid), + torch.tensor(mid - 1).type_as(relative_pos), + torch.abs(relative_pos), + ) + log_pos = ( + torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid + ) + bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign) return bucket_pos @@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- `torch.LongTensor`: A tensor with shape [1, query_size, key_size] """ - q_ids = np.arange(0, query_size) - k_ids = np.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + q_ids = torch.arange(0, query_size) + k_ids = torch.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - k_ids.repeat(q_ids.shape[0], 1) if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids.type(torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -767,7 +773,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = math.sqrt(query_layer.size(-1) * scale_factor) + scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -842,7 +848,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ score = 0 # content->position if "c2p" in self.pos_att_type: - scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float)) c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) c2p_att = torch.gather( @@ -854,7 +860,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # position->content if "p2c" in self.pos_att_type: - scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float)) if key_layer.size(-2) != query_layer.size(-2): r_pos = build_relative_position( key_layer.size(-2), From d34125245d77580618a2dc9d702298b8cfe65c62 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 13:17:09 +0200 Subject: [PATCH 06/13] Small refactor Co-authored-by: Michael Benayoun --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 0048a8355dcf..8be46852a7c7 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -786,7 +786,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # position->content if "p2c" in self.pos_att_type: - scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) if key_layer.size(-2) != query_layer.size(-2): r_pos = build_relative_position( key_layer.size(-2), From 6267493b7e4f3e2e640b79a67bf95d40a945e2c8 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 13:28:51 +0200 Subject: [PATCH 07/13] Use broadcasting instead of repeat --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 8be46852a7c7..5909191570f6 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -567,7 +567,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- """ q_ids = torch.arange(0, query_size) k_ids = torch.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - k_ids.repeat(q_ids.shape[0], 1) + rel_pos_ids = q_ids[:, None] - k_ids[None,:] if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) rel_pos_ids = rel_pos_ids.type(torch.long) From 7c0d6bf6d62a99273b3c326e49c7c0b45ab4d0c8 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 13:30:06 +0200 Subject: [PATCH 08/13] Implement suggestion Co-authored-by: Michael Benayoun --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 5909191570f6..d1ae895b6ab1 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -570,7 +570,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- rel_pos_ids = q_ids[:, None] - k_ids[None,:] if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = rel_pos_ids.type(torch.long) + rel_pos_ids = rel_pos_ids.to(torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids From 2823063b42daab260db06033e54593855e4bce2d Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 13:34:42 +0200 Subject: [PATCH 09/13] Match deberta v2 changes in sew_d --- src/transformers/models/sew_d/modeling_sew_d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 38afeff89e08..adffc6056eb2 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -229,10 +229,10 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- """ q_ids = torch.arange(0, query_size) k_ids = torch.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - k_ids.repeat(q_ids.shape[0], 1) + rel_pos_ids = q_ids[:, None] - k_ids[None,:] if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = rel_pos_ids.type(torch.long) + rel_pos_ids = rel_pos_ids.to(torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -773,7 +773,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) From 78f1d92d5dabd3e75d3cda05dfc9edefc9e92b73 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 14:03:13 +0200 Subject: [PATCH 10/13] Improve code quality --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index d1ae895b6ab1..5eea0700a5cc 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -567,7 +567,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- """ q_ids = torch.arange(0, query_size) k_ids = torch.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - k_ids[None,:] + rel_pos_ids = q_ids[:, None] - k_ids[None, :] if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) rel_pos_ids = rel_pos_ids.to(torch.long) From e5f8309e5a236543cabbc4efb941cdffb815627d Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 14:03:48 +0200 Subject: [PATCH 11/13] Update code quality --- src/transformers/models/sew_d/modeling_sew_d.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index adffc6056eb2..2c39499a70af 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -229,7 +229,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- """ q_ids = torch.arange(0, query_size) k_ids = torch.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - k_ids[None,:] + rel_pos_ids = q_ids[:, None] - k_ids[None, :] if bucket_size > 0 and max_position > 0: rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) rel_pos_ids = rel_pos_ids.to(torch.long) @@ -773,7 +773,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) + scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -860,7 +860,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # position->content if "p2c" in self.pos_att_type: - scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) if key_layer.size(-2) != query_layer.size(-2): r_pos = build_relative_position( key_layer.size(-2), @@ -1107,7 +1107,6 @@ def forward( rel_embeddings = self.get_rel_embedding() output_states = next_kv for i, layer_module in enumerate(self.layer): - if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) @@ -1571,7 +1570,6 @@ def forward( loss = None if labels is not None: - if labels.max() >= self.config.vocab_size: raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") From d07b4196ca362d1f063e65bf91be518fed201fc9 Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 14:22:35 +0200 Subject: [PATCH 12/13] Consistency of small refactor --- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 5eea0700a5cc..f814473f64d9 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -699,7 +699,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -774,7 +774,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ score = 0 # content->position if "c2p" in self.pos_att_type: - scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor) c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) c2p_att = torch.gather( From 08d305e7db46a74e438609bdd96d0ce846e83e3b Mon Sep 17 00:00:00 2001 From: iiLaurens Date: Thu, 11 Aug 2022 14:23:04 +0200 Subject: [PATCH 13/13] Match changes in sew_d --- src/transformers/models/sew_d/modeling_sew_d.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 2c39499a70af..271fd3a9bccf 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -773,7 +773,7 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) @@ -848,7 +848,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ score = 0 # content->position if "c2p" in self.pos_att_type: - scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float)) + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor) c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) c2p_att = torch.gather( @@ -1107,6 +1107,7 @@ def forward( rel_embeddings = self.get_rel_embedding() output_states = next_kv for i, layer_module in enumerate(self.layer): + if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) @@ -1570,6 +1571,7 @@ def forward( loss = None if labels is not None: + if labels.max() >= self.config.vocab_size: raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")