[EncoderDecoder] Add Cross Attention for GPT2 #6415
Conversation
Force-pushed from 7e7c8ad to ad5af2c
sgugger left a comment
Thanks, this looks great to me!
LysandreJik left a comment
Great, looks good to me!
Co-authored-by: Sylvain Gugger <[email protected]>
Codecov Report
@@            Coverage Diff             @@
##           master    #6415      +/-   ##
==========================================
- Coverage   79.98%   79.98%   -0.01%
==========================================
  Files         153      153
  Lines       28005    28039      +34
==========================================
+ Hits        22401    22427      +26
- Misses       5604     5612       +8
Continue to review full report at Codecov.
* Generation doc
* MBartForConditionalGeneration (#6441)
* add MBartForConditionalGeneration
* style
* rebase and fixes
* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS
* fix docs
* don't ignore mbart
* doc
* fix mbart fairseq link
* put mbart before bart
* apply doc suggestions
* Use hash to clean the test dirs (#6475)
* Use hash to clean the test dirs
* Use hash to clean the test dirs
* Use hash to clean the test dirs
* fix
* [EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2
* make gpt2 cross attention work
* finish bert2gpt2
* add explicit comments
* remove attention mask since not yet supported
* revert attn mask in pipeline
* Update src/transformers/modeling_gpt2.py
  Co-authored-by: Sylvain Gugger <[email protected]>
* Update src/transformers/modeling_encoder_decoder.py
  Co-authored-by: Sylvain Gugger <[email protected]>
  Co-authored-by: Sylvain Gugger <[email protected]>
* Sort unique_no_split_tokens to make it deterministic (#6461)
* change unique_no_split_tokens's type to set
* use sorted list instead of set
* style
* Import accuracy_score (#6480)
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <[email protected]>
* Address comments
* Styling
* Generation doc
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <[email protected]>
* Address comments
* Styling

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Kevin Canwen Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: gijswijnholds <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
This PR implements Bert2GPT2 by adding cross-attention layers to GPT2.
Note that it is not yet possible to speed up decoder generation within the encoder-decoder framework by using GPT2's past tensors: caching support would have to be implemented for all models compatible with the encoder/decoder framework (BERT, RoBERTa) before it can be enabled there.
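For context, a minimal sketch of how a Bert2GPT2 model can be warm-started and trained with the encoder-decoder framework is shown below. The checkpoint names, the toy article/summary pair, and the exact loss handling are illustrative assumptions, not code from this PR:

```python
# Minimal sketch (illustrative, not from this PR): warm-start a Bert2GPT2 model now
# that GPT2 has cross-attention layers. Checkpoint names and texts are placeholders.
from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel

# BERT acts as the encoder; GPT2 is loaded as the decoder with cross-attention
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

encoder_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

article = "The quick brown fox jumps over the lazy dog."
summary = "A fox jumps over a dog."

input_ids = encoder_tokenizer(article, return_tensors="pt").input_ids
labels = decoder_tokenizer(summary, return_tensors="pt").input_ids

# During training the target ids double as decoder inputs; the language-modeling
# loss comes first in the outputs when labels are provided
outputs = model(input_ids=input_ids, decoder_input_ids=labels, labels=labels)
loss = outputs[0]
```

At inference time, `model.generate(...)` can then be used for summarization-style decoding, though, as noted above, without the past-key caching speed-up for now.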
All GPT2 RUN_SLOW tests are verified to pass.
Future PRs TODO: