[Core] Support full cuda graph in v1 #16072
New test file (`@@ -0,0 +1,97 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.
    We have to do this instead of monkeypatch because monkeypatch doesn't
    work with "module"-scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.2,
                   compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.5,
                   compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
    prompts = ["Hi my name is"] * batch_size
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=max_tokens,
                                     top_p=0.95)

    return llm.generate(prompts, sampling_params)


@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                        (16, 10), (25, 10),
                                                        (32, 10), (45, 10),
                                                        (64, 10), (8, 5),
                                                        (8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    """
    Load the full-cudagraph model and the piecewise model once each and
    reuse them across the test cases.

    Test various batch sizes and max_tokens to ensure that the full cudagraph
    compilation works for padded cases too.
    """
    piecewise_responses = generate_text(piecewise_llm,
                                        batch_size=batch_size,
                                        max_tokens=max_tokens)
    full_cudagraph_responses = generate_text(full_cudagraph_llm,
                                             batch_size=batch_size,
                                             max_tokens=max_tokens)

    # Check that all responses are the same
    for i in range(len(piecewise_responses)):
        assert piecewise_responses[i].outputs[
            0].text == full_cudagraph_responses[i].outputs[0].text


def test_full_cudagraph_with_invalid_backend():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION":
            "2"  # FA2 not supported with full_cuda_graph
    }), pytest.raises(RuntimeError):
        LLM(model=MODEL,
            compilation_config=CompilationConfig(full_cuda_graph=True))
```
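For readers who want to try the feature outside the test suite, here is a minimal offline-inference sketch mirroring the fixture above. The model name and environment variables are simply the ones the test uses; the V1 engine and FlashAttention 3 are required, as enforced by the check added to the model runner further down.

```python
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Same preconditions as the test fixtures: V1 engine + FlashAttention 3.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          compilation_config=CompilationConfig(full_cuda_graph=True))
outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```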
Changes to the attention metadata builder's `build` method: instead of copying `query_start_loc`, `seq_lens`, and `slot_mapping` from CPU to the device on every call, the builder now slices the persistent device tensors owned by the runner.

```diff
@@ -294,15 +294,12 @@ def reorder_batch(self, input_batch: "InputBatch",
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
+        query_start_loc = self.runner.query_start_loc[:num_reqs + 1]
+        seq_lens = self.runner.seq_lens[:num_reqs]

         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]

         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
```

Review thread on the removed `slot_mapping` lines (-320 to -321):

**Member:** @WoosukKwon it's because we call

**Contributor (author):** @tlrmchlsmth what do you think about just making the CPU tensor int64 too? (That's the route I went with in the latest update on this PR.)

**Member:** Had to check - that takes the
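The reason for the switch is that `.to(device)` allocates a fresh device tensor on every call, whereas slicing a preallocated device tensor yields a view into the same fixed storage, and a captured CUDA graph keeps reading from that fixed address across replays. A PyTorch-only illustration of the difference (not vLLM code):

```python
import torch

if torch.cuda.is_available():
    # Slices of a persistent buffer are views over the same fixed storage.
    persistent = torch.zeros(1024, dtype=torch.int32, device="cuda")
    assert persistent[:16].data_ptr() == persistent.data_ptr()
    assert persistent[:32].data_ptr() == persistent.data_ptr()

    # .to("cuda") allocates a new device tensor each time it is called.
    cpu_src = torch.ones(16, dtype=torch.int32)
    a = cpu_src.to("cuda")
    b = cpu_src.to("cuda")
    assert a.data_ptr() != b.data_ptr()
```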
Changes to the V1 GPU model runner. First, import the FlashAttention version helper:

```diff
@@ -12,6 +12,7 @@
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
```

Then reject `full_cuda_graph` unless the backend is FlashAttention 3:

```diff
@@ -138,6 +139,16 @@ def __init__(
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")

+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
```
Persistent device-side buffers for `query_start_loc`, `seq_lens`, and `slot_mapping`, allocated once at maximum size:

```diff
@@ -215,6 +226,19 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)

         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
```
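These buffers are sized for the worst case (`max_num_reqs`, `max_num_tokens`) because a fully captured CUDA graph replays fixed kernel launches against fixed device addresses: new inputs must be copied into the captured buffers rather than handed over as freshly allocated tensors. A stand-alone PyTorch sketch of that capture/copy/replay pattern (illustrative only, not the runner's actual capture code):

```python
import torch

if torch.cuda.is_available():
    static_input = torch.zeros(8, device="cuda")   # persistent "max size" buffer
    static_output = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as recommended for CUDA graphs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output.copy_(static_input * 2)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the computation once, reading from / writing to the persistent buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output.copy_(static_input * 2)

    # New data is copied *into* the captured buffer, then the graph is replayed
    # without re-launching kernels from Python.
    static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
    graph.replay()
    assert torch.equal(static_output, static_input * 2)
```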
In `_prepare_inputs`, copy the per-step values into those buffers and pad the unused tails:

```diff
@@ -593,6 +617,19 @@ def _prepare_inputs(
                 scheduler_output.num_common_prefix_blocks,
             )

+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
```
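Note that the `slot_mapping` copy above also performs the int32-to-int64 conversion implicitly, since `Tensor.copy_` casts between dtypes, and the unused tail of each buffer is padded with sentinel values so that a run padded up to a larger captured size still sees well-defined data. A toy illustration (the buffer names and sizes are made up):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Persistent int64 device buffer (made-up size), fed from an int32 CPU buffer.
cpu_slots = torch.tensor([3, 7, 42], dtype=torch.int32)
gpu_slots = torch.zeros(8, dtype=torch.int64, device=device)

gpu_slots[:3].copy_(cpu_slots, non_blocking=True)  # H2D copy + int32 -> int64 cast
gpu_slots[3:].fill_(-1)  # pad the unused tail with -1, as reshape_and_cache expects
print(gpu_slots)  # tensor([ 3,  7, 42, -1, -1, -1, -1, -1], ...)
```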
`_dummy_run` gains a `skip_attn` flag; when it is left at the default, real attention metadata is built for the dummy batch so the attention kernels can be part of the captured graph:

```diff
@@ -1448,6 +1485,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = False,
     ) -> torch.Tensor:

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
```

```diff
@@ -1464,6 +1502,16 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)

+        if skip_attn:
+            attn_metadata = None
+        else:
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
```
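The profiling path below passes `skip_attn=True`, while graph capture relies on the default `skip_attn=False`, presumably so the captured graph includes the attention kernels. At run time a batch is padded up to one of the sizes captured during warmup, which is why the test above exercises "padded cases" such as batch sizes 7, 25, and 45. A hypothetical sketch of such a padding rule (the size list and helper are illustrative, not vLLM's actual implementation):

```python
import bisect

# Illustrative capture-size list; vLLM derives its own list from the config.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]


def padded_graph_size(num_tokens: int) -> int:
    """Smallest captured size that can hold the current batch of tokens."""
    idx = bisect.bisect_left(CAPTURE_SIZES, num_tokens)
    if idx == len(CAPTURE_SIZES):
        raise ValueError("batch is larger than the largest captured graph")
    return CAPTURE_SIZES[idx]


assert padded_graph_size(7) == 8    # e.g. 7 decode tokens pad up to 8
assert padded_graph_size(45) == 64  # under this illustrative size list
```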
Still in `_dummy_run`, pass the built metadata into the forward context instead of `None`:

```diff
@@ -1492,7 +1540,7 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })

-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
```

A review question on this change:

**Contributor:** Considering that ... Therefore, Full Cuda Graph should be incompatible with kvconnector?

In `profile_run`, skip building attention metadata for the profiling pass:

```diff
@@ -1649,7 +1697,7 @@ def profile_run(self) -> None:
         # Cache the dummy encoder outputs.
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

-        hidden_states = self._dummy_run(self.max_num_tokens)
+        hidden_states = self._dummy_run(self.max_num_tokens, skip_attn=True)

         if get_pp_group().is_last_rank:
             sampler_output = self._dummy_sampler_run(hidden_states)
         else:
```
A final review exchange on the command-line example for the new flag:

- Should be `--compilation-config "{'full_cuda_graph': True}"`. We use double quotes on the outside so that it is not treated as escape chars by some shells.
- Good catch, thanks.