[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models #7631
Merged
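For context on the technique this PR enables: CUDA graph decoding captures the fixed-shape decode step once and then replays the recorded kernel launches on every subsequent step, removing per-kernel launch overhead. Below is a minimal sketch of the generic capture/replay pattern using PyTorch's public torch.cuda.graph API. It illustrates the idea only and is not vLLM's model-runner code; the toy model, buffer names, and sizes are invented for the example.

```python
import torch

# Stand-ins for a real decode step: a toy model and a fixed-shape
# input buffer (CUDA graphs require static shapes and addresses).
model = torch.nn.Linear(512, 512).cuda().eval()
static_in = torch.zeros(8, 512, device="cuda")  # fixed batch size 8

# Warm up on a side stream so capture records fully initialized kernels.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(side)

# Capture one decode step into a graph; the ops are recorded, not run.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Per decode step: refresh the static input buffer, then replay the
# entire recorded launch sequence in one shot.
static_in.copy_(torch.randn(8, 512, device="cuda"))
graph.replay()
print(static_out[0, :4])  # static_out now holds this step's result
```

vLLM exposes the choice between this graph mode and plain eager execution through the enforce_eager flag, which the new test in this PR parametrizes over.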
Commits

91 commits, all authored by sroy745 (changes shown from 88 commits):

5650b95  Merge pull request #1 from vllm-project/main
8f36146  Merge branch 'vllm-project:main' into main
9e75057  Merge branch 'vllm-project:main' into main
db2c679  Merge branch 'vllm-project:main' into main
8d7512c  Merge branch 'vllm-project:main' into main
1473f74  Merge branch 'vllm-project:main' into main
4013e1a  Merge branch 'vllm-project:main' into main
2dbdd78  Merge branch 'vllm-project:main' into main
b3575e9  Merge branch 'vllm-project:main' into main
94b0d43  Merge branch 'vllm-project:main' into main
fa8fedf  Merge branch 'vllm-project:main' into main
6ed96b4  Merge branch 'vllm-project:main' into main
b71c533  Merge branch 'vllm-project:main' into main
57babef  Merge branch 'vllm-project:main' into main
4b19bac  Merge branch 'vllm-project:main' into main
eb7a1c4  Merge branch 'vllm-project:main' into main
7e2c87e  Merge branch 'vllm-project:main' into main
6212d5f  Merge branch 'vllm-project:main' into main
5491438  Merge branch 'vllm-project:main' into main
68e080a  Merge branch 'vllm-project:main' into main
55e4332  Merge branch 'vllm-project:main' into main
c08dc0a  Add cuda graph support during decoding for encoder-decoder models
fbc8837  Add logic for CUDA Graph capture
b3b4e4a  Remove extra line
bc9c32e  Remove debugs
532eb48  Merge branch 'vllm-project:main' into main
33342fb  Add comments
b3425d6  Merge branch 'main' into enc-dec-cuda-graph
7cea056  Merge branch 'vllm-project:main' into main
47ba15b  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
bf70ceb  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
cef7273  Merge
599cd6b  Move logic to backend/utils.py
a9ca02f  Fix import
6b09ee8  Fix formatting
c199b50  Fix test documentation
539b10e  Remove debug stmt
727b3f2  Add a new test
e2e16cf  Fix Batch Size
1fb7cc6  Fix comments
79e3928  Fix formatting
09f9741  Fix test to run with CUDA Graph
eb52df8  fix format
f27e84a  Dummy commit
ada2d05  Dummy commit
185e056  Merge branch 'vllm-project:main' into main
3c75775  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
7b94a35  Format
e2be95f  Merge branch 'vllm-project:main' into main
783b37b  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
4b72b1f  Addressing comments
2ed5473  Merge branch 'vllm-project:main' into main
efa4714  Merge branch 'vllm-project:main' into main
105f6c3  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
525c541  Addressing comments
9f3ad3f  Merge branch 'main' into enc-dec-cuda-graph
785dfc5  Fix format
758c8d2  Add tests
6f61f97  fix format
f93ac1c  Add comments
e408a00  Dummy fix
fb87d34  Merge branch 'vllm-project:main' into main
61af6ed  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
bf7b4fc  Dummy commit
a814a0b  Fix format
5419e49  Merge branch 'vllm-project:main' into main
2ca1c2f  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
4617e39  Adding back the assertion
9ba12f8  Merge branch 'vllm-project:main' into main
25cef3d  Merge branch 'vllm-project:main' into main
0d8b6bc  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
2ab85b8  Dummy commit
0592dc0  Dummy commit
9d4cd09  Merge branch 'vllm-project:main' into main
0b3c9e2  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
c8eb247  Merge branch 'main' into enc-dec-cuda-graph
12f8312  Dummy commit
0ea4479  Dummy commit
c48cacb  Merge branch 'vllm-project:main' into main
c42c399  Merge branch 'vllm-project:main' into main
3d13e43  Merge branch 'vllm-project:main' into main
ed68558  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
7417751  Dummy
12c4af4  Dummy
7cd2b43  Dummy Commit to rerun tests
e4338ce  Dummy Commit to rerun tests
7479775  Merge branch 'vllm-project:main' into main
9c54a60  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
df9b966  Merge branch 'vllm-project:main' into main
7cc1a49  Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
3fb360c  Address comments
Files changed

The diff adds one empty file, plus the new test file tests/encoder_decoder/test_e2e_correctness.py (+100 lines), reconstructed below.
```python
"""E2E tests to verify the correctness of the encoder-decoder framework

Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
"""
from typing import List, Optional, Tuple

from vllm.utils import is_cpu

if not is_cpu():
    # CPU backend is not currently supported with encoder/decoder models
    # skip test definitions entirely to avoid importing GPU kernel libs
    # (xFormers, etc.)

    import pytest
    from transformers import AutoModelForSeq2SeqLM

    from vllm.sequence import SampleLogprobs

    from ..conftest import DecoderPromptType
    from ..models.utils import check_logprobs_close

    def vllm_to_hf_output(
        vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
        decoder_prompt_type: DecoderPromptType,
    ):
        """Sanitize vllm output to be comparable with hf output."""
        output_ids, output_str, out_logprobs = vllm_output

        hf_output_str = output_str + "</s>"
        if decoder_prompt_type == DecoderPromptType.NONE:
            hf_output_str = "<s>" + hf_output_str

        return output_ids, hf_output_str, out_logprobs

    @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
    @pytest.mark.parametrize("dtype", ["bfloat16"])
    @pytest.mark.parametrize("max_tokens", [128])
    @pytest.mark.parametrize("num_logprobs", [5])
    @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
    @pytest.mark.parametrize("enforce_eager", [True, False])
    def test_encoder_decoder_e2e(
        hf_runner,
        vllm_runner,
        example_encoder_decoder_prompts,
        model: str,
        dtype: str,
        max_tokens: int,
        num_logprobs: int,
        decoder_prompt_type: DecoderPromptType,
        enforce_eager: bool,
    ) -> None:
        '''
        End-to-End (E2E) test for the encoder-decoder framework.
        This test evaluates the encoder-decoder functionality using the BART
        model. We compare the outputs of the Hugging Face and vLLM
        implementations to ensure that both implementations produce consistent
        and correct results.
        '''
        test_case_prompts = example_encoder_decoder_prompts[
            decoder_prompt_type]

        # Configuration settings for HF baseline
        hf_kwargs = {
            "top_k": None,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "length_penalty": 1.0,
            "early_stopping": False,
            "no_repeat_ngram_size": None,
            "min_length": 0
        }

        with hf_runner(model, dtype=dtype,
                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
            hf_outputs = (
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    test_case_prompts,
                    max_tokens,
                    num_logprobs,
                    **hf_kwargs,
                ))
        with vllm_runner(model, dtype=dtype,
                         enforce_eager=enforce_eager) as vllm_model:
            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
                test_case_prompts, max_tokens, num_logprobs)

        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                          else 0)

        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, decoder_prompt_type)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
            num_outputs_0_skip_tokens=hf_skip_tokens,
        )
```
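To try the same comparison outside pytest, here is a hedged sketch using vLLM's public LLM API with the model and enforce_eager flag the test parametrizes. The prompt string and sampling settings are invented for illustration, and exact encoder/decoder prompt handling may vary by vLLM version:

```python
from vllm import LLM, SamplingParams

# Greedy decoding with logprobs, mirroring the test's settings.
params = SamplingParams(temperature=0.0, max_tokens=128, logprobs=5)

# enforce_eager=True forces plain eager execution; enforce_eager=False
# lets vLLM capture CUDA graphs for decode steps (the path this PR adds
# for encoder-decoder models).
for enforce_eager in (True, False):
    llm = LLM(model="facebook/bart-large-cnn",
              dtype="bfloat16",
              enforce_eager=enforce_eager)
    # For an encoder-decoder model, a plain string serves as the
    # encoder prompt.
    outputs = llm.generate("A fast brown fox jumps over a sleeping dog.",
                           params)
    print(f"enforce_eager={enforce_eager}: "
          f"{outputs[0].outputs[0].text!r}")
```

In practice, run each LLM instance in its own process; two engines in one process can contend for GPU memory.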