
Commit ab2006e

BART - Fix attention mask device issue on copied models (#18540)
* attempt to fix attn mask device
* fix bart `_prepare_decoder_attention_mask`
  - add correct device
  - run `make fix-copies` to propagate the fix
1 parent 6bea7b8 commit ab2006e
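
The change is identical in every file: the expanded decoder attention mask is explicitly moved to the device of `inputs_embeds`, so that models whose weights are spread across devices (for example when loaded with `device_map="auto"`) no longer hit a CPU/GPU mismatch when the padding mask is combined with the causal mask. Below is a minimal, self-contained sketch of the pattern; `_expand_mask` here is a simplified stand-in for the helper in `modeling_bart.py`, and `prepare_decoder_attention_mask` is illustrative rather than the library code.

    import torch


    def _expand_mask(mask, dtype, tgt_len=None):
        # Expand a [bsz, src_len] padding mask to [bsz, 1, tgt_len, src_len],
        # with 0.0 where attention is allowed and a large negative value elsewhere.
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted = 1.0 - expanded
        return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


    def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, combined_attention_mask=None):
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            # Without `.to(inputs_embeds.device)` the expanded mask stays on
            # `attention_mask`'s device (often CPU), and combining it with a
            # causal mask that lives on the GPU raises a device-mismatch error.
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask


    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        attn_mask = torch.ones(2, 5)                      # created on CPU
        embeds = torch.randn(2, 5, 16, device=device)     # may live on GPU
        mask = prepare_decoder_attention_mask(attn_mask, embeds.shape[:2], embeds)
        print(mask.shape, mask.device)                    # torch.Size([2, 1, 5, 5]) on `device`
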

File tree: 9 files changed, +27 -9 lines changed

src/transformers/models/bart/modeling_bart.py (3 additions, 1 deletion)

@@ -915,7 +915,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (3 additions, 1 deletion)

@@ -2116,7 +2116,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/blenderbot/modeling_blenderbot.py (3 additions, 1 deletion)

@@ -854,7 +854,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py (3 additions, 1 deletion)

@@ -850,7 +850,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/marian/modeling_marian.py (3 additions, 1 deletion)

@@ -860,7 +860,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/mbart/modeling_mbart.py (3 additions, 1 deletion)

@@ -913,7 +913,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/opt/modeling_opt.py (3 additions, 1 deletion)

@@ -534,7 +534,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/pegasus/modeling_pegasus.py (3 additions, 1 deletion)

@@ -880,7 +880,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )

src/transformers/models/plbart/modeling_plbart.py (3 additions, 1 deletion)

@@ -887,7 +887,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
