[Core] Support full cuda graph in v1 #16072
Changes from all commits
a7a16df
6aa65db
c338c99
15edc9b
f9d41b4
bdb1747
4700afd
e55d023
353ea66
1dcdb37
08b8d6a
aa9e4b6
7e821e0
0d1a796
166e6a6
5deacad
6b523ac
3cfd971
59e52e6
22fa9df
659d9b9
New test file (diff hunk `@@ -0,0 +1,97 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.
    We have to do this instead of using monkeypatch because monkeypatch
    doesn't work with "module"-scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.2,
                   compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.5,
                   compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
    prompts = ["Hi my name is"] * batch_size
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=max_tokens,
                                     top_p=0.95)

    return llm.generate(prompts, sampling_params)


@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                        (16, 10), (25, 10),
                                                        (32, 10), (45, 10),
                                                        (64, 10), (8, 5),
                                                        (8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    """
    Load the full cudagraph model and the piecewise model once, and reuse
    them across the test cases.

    Test various batch sizes and max_tokens to ensure that the full cudagraph
    compilation works for padded cases too.
    """
    piecewise_responses = generate_text(piecewise_llm,
                                        batch_size=batch_size,
                                        max_tokens=max_tokens)
    full_cudagraph_responses = generate_text(full_cudagraph_llm,
                                             batch_size=batch_size,
                                             max_tokens=max_tokens)

    # Check that all responses are the same
    for i in range(len(piecewise_responses)):
        assert piecewise_responses[i].outputs[
            0].text == full_cudagraph_responses[i].outputs[0].text


def test_full_cudagraph_with_invalid_backend():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "2"  # FA2 is not supported with full_cuda_graph
    }), pytest.raises(RuntimeError):
        LLM(model=MODEL,
            compilation_config=CompilationConfig(full_cuda_graph=True))
```
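For reference, the option exercised by these tests can be enabled directly through the same API. A minimal sketch, reusing the model and settings from the test fixtures above (it assumes the V1 engine and FlashAttention 3 are available, as the fixtures require):

```python
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Same model and config path as the test above; requires VLLM_USE_V1=1 and FA3.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          compilation_config=CompilationConfig(full_cuda_graph=True))
outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```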
Changes to the V1 `GPUModelRunner`:

```diff
@@ -12,6 +12,7 @@
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
```
```diff
@@ -139,6 +140,16 @@ def __init__(
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")

+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
```
```diff
@@ -219,6 +230,16 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)

         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
```
```diff
@@ -271,7 +292,7 @@ def __init__(
                                              pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
```
```diff
@@ -589,10 +610,22 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]

         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
```
```diff
@@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
```
```diff
@@ -1494,6 +1528,23 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)

+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
```
```diff
@@ -1522,7 +1573,7 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })

-        with set_forward_context(None,
+        with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_tokens):
             outputs = model(
```

Contributor comment on this line: Considering that [...], should full cuda graph therefore be incompatible with the KV connector?
```diff
@@ -1708,11 +1759,12 @@ def capture_model(self) -> None:
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
```
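For background, the warmup-then-capture loop in `capture_model` follows the standard PyTorch CUDA Graphs recipe: run the workload a few times before capture, then capture it with static input/output tensors and replay it later after refreshing the inputs in place. A minimal generic sketch of that recipe (not vLLM code; requires a CUDA device):

```python
import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward pass with fixed shapes.
    return (x * 2.0).relu()


static_input = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as CUDA Graphs require.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        forward(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture once; tensor addresses inside the graph are now fixed.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = forward(static_input)

# Replay: refresh the static input in place, then replay the graph.
static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
graph.replay()
print(static_output)
```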
Review discussion on the `slot_mapping_cpu` dtype change:

@WoosukKwon it's because we call `.long()` here. We might want to still call it here, to keep the dtypes consistent in the model runner.

@tlrmchlsmth what do you think about just making the CPU tensor int64 too? (That's the route that I went with in the latest update on this PR.)

Had to check - that takes the `slot_mapping` CPU -> GPU transfer from 32 KB to 64 KB (by default serving on an H100). That seems fine to me since we no longer do that copy in every layer.
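The 32 KB -> 64 KB figure is consistent with a `slot_mapping` buffer of 8192 entries (an assumed default for the maximum number of batched tokens) moving from 4-byte int32 to 8-byte int64 elements; a quick check:

```python
# Assumed default max_num_batched_tokens; the exact value depends on config.
max_num_tokens = 8192

int32_kb = max_num_tokens * 4 / 1024  # torch.int32 slot_mapping
int64_kb = max_num_tokens * 8 / 1024  # torch.int64 slot_mapping
print(f"{int32_kb:.0f} KB -> {int64_kb:.0f} KB")  # 32 KB -> 64 KB
```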