Skip to content

Commit 566b083

Browse files
TFMarian, TFMbart, TFPegasus, TFBlenderbot (#7987)
* Start plumbing * Marian close * Small stubs for all children * Fixed bart * marian working * pegasus test is good, but failing * Checkin tests * More model files * Subtle marian, pegasus integration test failures * Works well * rm print * boom boom * Still failing model2doc * merge master * Equivalence test failing, all others fixed * cleanup * Fix embed_scale * Cleanup marian pipeline test * Undo extra changes * Smaller delta * Cleanup model testers * undo delta * fix tests import structure * cross test decorator * Cleaner set_weights * Respect authorized_unexpected_keys * No warnings * No warnings * style * Nest tf import * black * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * functional dropout * fixup * Fixup * style_doc * embs * shape list * delete slow force_token_id_to_be_generated func * fixup Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
1 parent 6279072 commit 566b083

20 files changed

Lines changed: 1063 additions & 106 deletions

docs/source/model_doc/blenderbot.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,12 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
9595

9696
.. autoclass:: transformers.BlenderbotForConditionalGeneration
9797
:members:
98+
99+
100+
TFBlenderbotForConditionalGeneration
101+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102+
103+
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
104+
105+
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
106+
:members:

docs/source/model_doc/marian.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,9 @@ MarianMTModel
129129
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
130130

131131
.. autoclass:: transformers.MarianMTModel
132+
133+
134+
TFMarianMTModel
135+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
136+
137+
.. autoclass:: transformers.TFMarianMTModel

docs/source/model_doc/mbart.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,11 @@ MBartForConditionalGeneration
7979
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8080

8181
.. autoclass:: transformers.MBartForConditionalGeneration
82-
:members: forward
82+
:members:
83+
84+
85+
TFMBartForConditionalGeneration
86+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
87+
88+
.. autoclass:: transformers.TFMBartForConditionalGeneration
89+
:members:

docs/source/model_doc/pegasus.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ PegasusForConditionalGeneration
9595
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9696

9797
.. autoclass:: transformers.PegasusForConditionalGeneration
98+
99+
100+
TFPegasusForConditionalGeneration
101+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102+
103+
.. autoclass:: transformers.TFPegasusForConditionalGeneration

src/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@
670670
TFBertModel,
671671
TFBertPreTrainedModel,
672672
)
673+
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
673674
from .modeling_tf_camembert import (
674675
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
675676
TFCamembertForMaskedLM,
@@ -750,6 +751,8 @@
750751
TFLxmertPreTrainedModel,
751752
TFLxmertVisualFeatureEncoder,
752753
)
754+
from .modeling_tf_marian import TFMarianMTModel
755+
from .modeling_tf_mbart import TFMBartForConditionalGeneration
753756
from .modeling_tf_mobilebert import (
754757
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
755758
TFMobileBertForMaskedLM,
@@ -771,6 +774,7 @@
771774
TFOpenAIGPTModel,
772775
TFOpenAIGPTPreTrainedModel,
773776
)
777+
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
774778
from .modeling_tf_roberta import (
775779
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
776780
TFRobertaForMaskedLM,

src/transformers/modeling_bart.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,6 @@ def forward(
427427
output_attentions=False,
428428
):
429429
residual = x
430-
431430
if layer_state is None:
432431
layer_state = {}
433432
if self.normalize_before:
@@ -447,7 +446,7 @@ def forward(
447446
if not self.normalize_before:
448447
x = self.self_attn_layer_norm(x)
449448

450-
# Cross attention
449+
# Cross-Attention Block
451450
residual = x
452451
assert self.encoder_attn.cache_key != self.self_attn.cache_key
453452
if self.normalize_before:
@@ -628,7 +627,6 @@ def forward(
628627
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
629628

630629
next_cache = next_decoder_cache if use_cache else None
631-
632630
if not return_dict:
633631
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
634632
return BaseModelOutputWithPast(

src/transformers/modeling_tf_auto.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
XLNetConfig,
4242
replace_list_option_in_docstrings,
4343
)
44+
from .configuration_blenderbot import BlenderbotConfig
45+
from .configuration_marian import MarianConfig
46+
from .configuration_mbart import MBartConfig
47+
from .configuration_pegasus import PegasusConfig
4448
from .configuration_utils import PretrainedConfig
4549
from .file_utils import add_start_docstrings
4650
from .modeling_tf_albert import (
@@ -63,6 +67,7 @@
6367
TFBertLMHeadModel,
6468
TFBertModel,
6569
)
70+
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
6671
from .modeling_tf_camembert import (
6772
TFCamembertForMaskedLM,
6873
TFCamembertForMultipleChoice,
@@ -108,6 +113,8 @@
108113
)
109114
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
110115
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
116+
from .modeling_tf_marian import TFMarianMTModel
117+
from .modeling_tf_mbart import TFMBartForConditionalGeneration
111118
from .modeling_tf_mobilebert import (
112119
TFMobileBertForMaskedLM,
113120
TFMobileBertForMultipleChoice,
@@ -118,6 +125,7 @@
118125
TFMobileBertModel,
119126
)
120127
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
128+
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
121129
from .modeling_tf_roberta import (
122130
TFRobertaForMaskedLM,
123131
TFRobertaForMultipleChoice,
@@ -210,6 +218,7 @@
210218
(T5Config, TFT5ForConditionalGeneration),
211219
(DistilBertConfig, TFDistilBertForMaskedLM),
212220
(AlbertConfig, TFAlbertForMaskedLM),
221+
(MarianConfig, TFMarianMTModel),
213222
(BartConfig, TFBartForConditionalGeneration),
214223
(CamembertConfig, TFCamembertForMaskedLM),
215224
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
@@ -261,8 +270,16 @@
261270
]
262271
)
263272

273+
264274
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
265-
[(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
275+
[
276+
(T5Config, TFT5ForConditionalGeneration),
277+
(MarianConfig, TFMarianMTModel),
278+
(MBartConfig, TFMBartForConditionalGeneration),
279+
(PegasusConfig, TFPegasusForConditionalGeneration),
280+
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
281+
(BartConfig, TFBartForConditionalGeneration),
282+
]
266283
)
267284

268285
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(

0 commit comments

Comments
 (0)