[Core] Support full cuda graph in v1 #16072
New test file (33 added lines): the full CUDA graph configuration must produce the same output as the piecewise configuration.

```python
# SPDX-License-Identifier: Apache-2.0
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel


def run_model(compilation_config: CompilationConfig):
    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=20)

    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
              compilation_config=compilation_config)

    return llm.generate(prompts, sampling_params)


def test_full_cudagraph(monkeypatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_FLASH_ATTN_VERSION", "3")

        full_cudagraph_responses = run_model(
            compilation_config=CompilationConfig(
                level=CompilationLevel.FULL_GRAPH,
                use_cudagraph=True,
            ))

        piecewise_responses = run_model(compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
        ))

        assert full_cudagraph_responses[0].outputs[
            0].text == piecewise_responses[0].outputs[0].text
```
Changes to the compilation config (`CompilationLevel` / `CompilationConfig`):

```diff
@@ -3065,6 +3065,7 @@ class CompilationLevel:
     DYNAMO_AS_IS = 1
     DYNAMO_ONCE = 2
     PIECEWISE = 3
+    FULL_GRAPH = 4


 class CompilationConfig(BaseModel):
@@ -3077,6 +3078,7 @@ class CompilationConfig(BaseModel):
         - 1: dynamo as is.
         - 2: dynamo once.
         - 3: piecewise compilation.
+        - 4: full compilation.
     - debug_dump_path: the path to dump the debug information.
     - cache_dir: the directory to store the compiled graph, to
         accelerate Inductor compilation. By default, it will use
@@ -3088,6 +3090,7 @@ class CompilationConfig(BaseModel):
         We use string to avoid serialization issues when using compilation in a distributed setting.
         When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
         When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
+        When the compilation level is 4, the backend is used for the full graph.
     - custom_ops: fine-grained control over which custom ops to enable/disable.
         Use 'all' to enable all, 'none' to disable all.
         Also specify a list of custom op names to enable (prefixed with a '+'),
@@ -3260,7 +3263,7 @@ def __repr__(self) -> str:
     @classmethod
     def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
-        if cli_value in ["0", "1", "2", "3"]:
+        if cli_value in ["0", "1", "2", "3", "4"]:
             return cls(level=int(cli_value))
         # do not use `eval`, it is dangerous and can execute arbitrary code
         dict_value = ast.literal_eval(cli_value)
```
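For orientation, a minimal sketch (not part of this diff) of how the new level can be selected, reusing the model and options from the test above; the string form is what `from_cli` parses when the level is passed on the command line (e.g. via `-O 4`, as the capture warning later in this PR suggests):

```python
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel

# Programmatic form: full-graph compilation with CUDA graphs enabled.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          compilation_config=CompilationConfig(
              level=CompilationLevel.FULL_GRAPH,  # == 4
              use_cudagraph=True,
          ))
print(llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=20)))

# String form, as parsed from the CLI.
config = CompilationConfig.from_cli("4")
assert config.level == CompilationLevel.FULL_GRAPH
```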
The backend initialization and the V1 splitting-ops helper are updated for the new level:

```diff
@@ -3327,7 +3330,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:

         # TODO: pass user-specified backend to piecewise compilation
         # merge with the config use_inductor
-        assert self.level == CompilationLevel.PIECEWISE
+        assert self.level >= CompilationLevel.PIECEWISE

         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)
@@ -3382,13 +3385,15 @@ def init_with_cudagraph_sizes(self,
                 self.max_capture_size] = self.max_capture_size

     def set_splitting_ops_for_v1(self):
         # If default, override splitting ops for piecewise cudagraph on V1.
         # NOTE: this function needs to be called
         if not self.splitting_ops:
-            self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
-            ]
+            if self.level == CompilationLevel.PIECEWISE:
+                self.splitting_ops = [
+                    "vllm.unified_attention",
+                    "vllm.unified_attention_with_output",
+                ]
+            elif self.level == CompilationLevel.FULL_GRAPH:
+                self.splitting_ops = []
```
A suggested change on the new `FULL_GRAPH` branch:

```diff
-            elif self.level == CompilationLevel.FULL_GRAPH:
-                self.splitting_ops = []
+            else:
+                assert not self.splitting_ops
```

Reviewer: We should remove the `if not self.splitting_ops:` too, right? Otherwise the assert would be redundant:

```python
if not self.splitting_ops:
    if self.level == CompilationLevel.PIECEWISE:
        self.splitting_ops = [
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
        ]
    else:
        assert not self.splitting_ops
```
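Putting the diff and the suggestion together, the intended shape of the helper is roughly the following (a sketch of the resolved logic, not necessarily the final merged code):

```python
def set_splitting_ops_for_v1(self):
    # Only override the default; a user-supplied list is left untouched.
    if not self.splitting_ops:
        if self.level == CompilationLevel.PIECEWISE:
            # Piecewise: split the graph at the attention ops so each piece
            # gets its own CUDA graph while attention runs outside of them.
            self.splitting_ops = [
                "vllm.unified_attention",
                "vllm.unified_attention_with_output",
            ]
        elif self.level == CompilationLevel.FULL_GRAPH:
            # Full graph: no splitting, so attention is captured as well.
            self.splitting_ops = []
```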
Changes to the V1 attention metadata builder: `build` now reads the runner's persistent device tensors instead of copying from CPU each step.

```diff
@@ -286,15 +286,12 @@ def reorder_batch(self, input_batch: "InputBatch",
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
+        query_start_loc = self.runner.query_start_loc[:num_reqs + 1]
+        seq_lens = self.runner.seq_lens[:num_reqs]

         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]

         # for local attention
         local_attn_metadata = None
```

Review thread on the removed `slot_mapping` copy (original lines 320-321):

- Member: @WoosukKwon it's because we call …
- Contributor (author): @tlrmchlsmth what do you think about just making the CPU tensor int64 too? (That's the route I went with in the latest update on this PR.)
- Member: Had to check - that takes the …
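On the dtype point above: `torch.Tensor.copy_` casts between dtypes, so the pinned int32 CPU buffer can be copied straight into the persistent int64 device buffer each step. A standalone illustration of that pattern (illustrative only, not vLLM code; buffer sizes are arbitrary):

```python
import torch

# Persistent device buffer; downstream kernels such as reshape_and_cache
# expect int64 slot indices.
slot_mapping_gpu = torch.zeros(8192, dtype=torch.int64, device="cuda")
# Pinned int32 staging buffer filled on the CPU side every step.
slot_mapping_cpu = torch.zeros(8192, dtype=torch.int32, pin_memory=True)

num_tokens = 100
slot_mapping_cpu[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32)

# copy_ performs the int32 -> int64 conversion as part of the H2D copy.
slot_mapping_gpu[:num_tokens].copy_(slot_mapping_cpu[:num_tokens],
                                    non_blocking=True)
# Mark the unused tail as invalid so stale entries are ignored.
slot_mapping_gpu[num_tokens:].fill_(-1)
```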
Changes to the V1 GPU model runner: CUDA graphs are enabled for any level >= PIECEWISE, and `query_start_loc`, `seq_lens`, and `slot_mapping` get persistent device-side buffers.

```diff
@@ -185,7 +185,7 @@ def __init__(
         )

         self.use_cuda_graph = (self.vllm_config.compilation_config.level
-                               == CompilationLevel.PIECEWISE
+                               >= CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -206,6 +206,19 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)

         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
```
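These buffers are persistent because a captured CUDA graph replays fixed memory addresses: everything the captured region reads, including the attention metadata tensors, has to live in memory that is reused across steps and updated in place. A minimal standalone illustration of that constraint with raw `torch.cuda.CUDAGraph` (not vLLM code):

```python
import torch

static_x = torch.zeros(1024, device="cuda")  # persistent input buffer
static_y = torch.zeros(1024, device="cuda")  # persistent output buffer

def step():
    # The work to be replayed; it only touches the persistent buffers.
    static_y.copy_(static_x * 2 + 1)

# Warm up before capture, then capture one step into a graph.
for _ in range(3):
    step()
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    step()

# To run new data, overwrite the input buffer in place and replay.
static_x.copy_(torch.arange(1024, dtype=torch.float32, device="cuda"))
graph.replay()
torch.cuda.synchronize()
assert torch.allclose(static_y, static_x * 2 + 1)
```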
In `_prepare_inputs`, the CPU-side values are copied into those persistent buffers and the unused tails are reset:

```diff
@@ -584,6 +597,19 @@ def _prepare_inputs(
             scheduler_output.num_common_prefix_blocks,
         )

+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
```
`_dummy_run` gains an `initialize_attention_metadata` flag so real attention metadata can be built for the dummy batches used during capture:

```diff
@@ -1392,6 +1418,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        initialize_attention_metadata: bool = False,
     ) -> torch.Tensor:

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1408,6 +1435,16 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)

+        if initialize_attention_metadata:
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+            )
+        else:
+            attn_metadata = None
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
```
The dummy forward pass now sees that metadata instead of `None`:

```diff
@@ -1436,7 +1473,7 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })

-        with set_forward_context(None,
+        with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_tokens):
             hidden_states = model(
```

Contributor: Considering that … Therefore, full CUDA graph should be incompatible with kvconnector?
The capture-skip warning now mentions both levels:

```diff
@@ -1603,7 +1640,8 @@ def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
                 "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
+                "-O %s or -O %s to use CUDA graphs.",
+                CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH)
             return

         start_time = time.perf_counter()
```
Warm-up and capture runs now build attention metadata:

```diff
@@ -1616,8 +1654,9 @@ def capture_model(self) -> None:
         for num_tokens in reversed(self.cudagraph_batch_sizes):
             for _ in range(self.vllm_config.compilation_config.
                            cudagraph_num_of_warmups):
-                self._dummy_run(num_tokens)
-            self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens,
+                                initialize_attention_metadata=True)
+            self._dummy_run(num_tokens, initialize_attention_metadata=True)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
```