GPU Model Runner V2 #25266
Changes from 90 commits
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
                             ModelRunnerOutput, SamplerOutput)


class AsyncOutput(AsyncModelRunnerOutput):

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampler_output: SamplerOutput,
        copy_stream: torch.cuda.Stream,
    ):
        self.model_runner_output = model_runner_output
        self.sampler_output = sampler_output
        self.copy_stream = copy_stream
        self.copy_event = torch.cuda.Event()

        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self.copy_stream):
            self.copy_stream.wait_stream(default_stream)

            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
                "cpu", non_blocking=True)
            x = sampler_output.logprobs_tensors
            if x is not None:
                self.logprobs_tensors = LogprobsTensors(
                    logprob_token_ids=x.logprob_token_ids.to(
                        "cpu", non_blocking=True),
                    logprobs=x.logprobs.to("cpu", non_blocking=True),
                    selected_token_ranks=x.selected_token_ranks.to(
                        "cpu", non_blocking=True),
                )
            else:
                self.logprobs_tensors = None
            self.copy_event.record()

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()
        self.model_runner_output.sampled_token_ids = (
            self.sampled_token_ids.numpy())
        if self.logprobs_tensors is not None:
            self.model_runner_output.logprobs = (
                self.logprobs_tensors.tolists())
        return self.model_runner_output
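The key mechanism here is overlapping the device-to-host copy of sampled tokens and logprobs with other work, and paying the synchronization cost only when `get_output()` is called. A minimal standalone sketch of that pattern (plain PyTorch, my illustration rather than code from this PR; `async_d2h_copy` is a hypothetical helper name):

```python
import torch

def async_d2h_copy(gpu_tensor: torch.Tensor,
                   copy_stream: torch.cuda.Stream):
    # Capture the stream that produced gpu_tensor *before* switching streams,
    # so the copy stream can wait on it (as AsyncOutput.__init__ does).
    producer_stream = torch.cuda.current_stream()
    copy_event = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(producer_stream)
        # Non-blocking D2H copy; it only truly overlaps if the destination is
        # pinned host memory, otherwise it degrades to a synchronous copy.
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        copy_event.record()
    return cpu_tensor, copy_event

# Usage: kick off the copy, keep doing CPU-side work, and block only when the
# host tensor is actually read (this mirrors AsyncOutput.get_output()).
if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    ids = torch.randint(0, 32000, (8, 1), device="cuda")
    cpu_ids, event = async_d2h_copy(ids, stream)
    event.synchronize()   # cpu_ids is safe to read after this point
    print(cpu_ids.numpy())
```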
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec, SlidingWindowSpec)
from vllm.v1.worker.utils import bind_kv_cache


def get_kv_cache_spec(
    vllm_config: VllmConfig,
    kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]:
    block_size = vllm_config.cache_config.block_size
    use_mla = vllm_config.model_config.use_mla

    kv_cache_spec: dict[str, KVCacheSpec] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer_name, attn_module in attn_layers.items():
        assert attn_module.attn_type == AttentionType.DECODER
        if attn_module.sliding_window is not None:
            kv_cache_spec[layer_name] = SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=attn_module.num_kv_heads,
                head_size=attn_module.head_size,
                dtype=kv_cache_dtype,
                sliding_window=attn_module.sliding_window,
                use_mla=use_mla,
            )
        else:
            kv_cache_spec[layer_name] = FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=attn_module.num_kv_heads,
                head_size=attn_module.head_size,
                dtype=kv_cache_dtype,
                use_mla=use_mla,
            )
    return kv_cache_spec
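For intuition, the result is one spec per attention layer, keyed by layer name. A small illustration (hypothetical layer names, sizes, and window length, using the same constructor arguments as above) of what the returned mapping might look like for a model that mixes sliding-window and full-attention decoder layers:

```python
import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

# Hypothetical example of a get_kv_cache_spec() result; the layer names, head
# counts, and 4096-token window are made up for illustration.
example_spec = {
    "model.layers.0.self_attn.attn": SlidingWindowSpec(
        block_size=16, num_kv_heads=8, head_size=128,
        dtype=torch.bfloat16, sliding_window=4096, use_mla=False),
    "model.layers.1.self_attn.attn": FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=128,
        dtype=torch.bfloat16, use_mla=False),
}
```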

def init_attn_backend(
    kv_cache_config: KVCacheConfig,
    vllm_config: VllmConfig,
    device: torch.device,
):
    attn_backends: dict[str, AttentionBackend] = {}
    attn_metadata_builders: list[AttentionMetadataBuilder] = []

    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
        layer_names = kv_cache_group_spec.layer_names
        any_layer_name = next(iter(layer_names))
Collaborator
This appears to assume an always-on hybrid KV cache manager. I am 100% supportive of this, but IIRC the reason we still supported disabling the hybrid KV cache manager was that P/D did not support it yet. cc @NickLucche @heheda12345: do you know the state of P/D + hybrid KV cache? Is there any other reason we would want to disable the hybrid KV cache manager?

Collaborator
Still not supported on P/D; @KuntaiDu has a series of PRs to enable the hybrid allocator + KV connectors first.

Collaborator
What about adding an assertion to make sure each KV cache group uses only one attention backend? For example, see the sketch below.
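One possible shape for that check, as a sketch on my part (the reviewer's original snippet is not shown above; this assumes `get_attn_backend()` returns the same backend class object for layers that share a backend):

```python
# Sketch only: verify every layer in the KV cache group resolves to the same
# attention backend as the representative layer before reusing it for all.
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
    assert attn_layers[layer_name].get_attn_backend() is attn_backend, (
        f"All layers in a KV cache group must use the same attention backend; "
        f"{layer_name} differs from {any_layer_name}")
```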
        attn_backend = attn_layers[any_layer_name].get_attn_backend()
        for layer_name in layer_names:
            attn_backends[layer_name] = attn_backend

        attn_metadata_builder = attn_backend.get_builder_cls()(
            kv_cache_group_spec.kv_cache_spec,
            layer_names,
            vllm_config,
            device,
        )
        attn_metadata_builders.append(attn_metadata_builder)
    return attn_backends, attn_metadata_builders

def _allocate_kv_cache(
    kv_cache_config: KVCacheConfig,
    device: torch.device,
):
    kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
        tensor = torch.zeros(kv_cache_tensor.size,
                             dtype=torch.int8,
                             device=device)
        for layer_name in kv_cache_tensor.shared_by:
            kv_cache_raw_tensors[layer_name] = tensor

    layer_names = set()
    for group in kv_cache_config.kv_cache_groups:
        for layer_name in group.layer_names:
            layer_names.add(layer_name)
    assert layer_names == set(kv_cache_raw_tensors.keys()), \
        "Some layers are not correctly initialized"
    return kv_cache_raw_tensors
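The `shared_by` loop is what lets several layers (for example, layers in different KV cache groups under the hybrid allocator) alias one physical buffer. A standalone sketch of that aliasing, with hypothetical sizes and layer names and plain dicts standing in for `KVCacheTensor`:

```python
import torch

# Plain-dict stand-ins for KVCacheTensor(size=..., shared_by=[...]).
kv_cache_tensors = [
    {"size": 2 * 1024 * 1024, "shared_by": ["layers.0.attn", "layers.2.attn"]},
    {"size": 2 * 1024 * 1024, "shared_by": ["layers.1.attn"]},
]

kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for t in kv_cache_tensors:
    buf = torch.zeros(t["size"], dtype=torch.int8)  # one flat byte buffer
    for layer_name in t["shared_by"]:
        kv_cache_raw_tensors[layer_name] = buf      # same storage, many names

# Layers listed together in shared_by really do alias the same memory.
assert (kv_cache_raw_tensors["layers.0.attn"].data_ptr()
        == kv_cache_raw_tensors["layers.2.attn"].data_ptr())
```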

def _reshape_kv_cache(
    kv_cache_config: KVCacheConfig,
    kv_cache_raw_tensors: dict[str, torch.Tensor],
    attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
    kv_caches: dict[str, torch.Tensor] = {}
    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
        kv_cache_spec = kv_cache_group_spec.kv_cache_spec
        for layer_name in kv_cache_group_spec.layer_names:
            raw_tensor = kv_cache_raw_tensors[layer_name]
            assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
            num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes

            attn_backend = attn_backends[layer_name]
            kv_cache_shape = attn_backend.get_kv_cache_shape(
                num_blocks, kv_cache_spec.block_size,
                kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)

            dtype = kv_cache_spec.dtype
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
            kv_cache_shape = tuple(kv_cache_shape[i]
                                   for i in kv_cache_stride_order)

            inv_order = [
                kv_cache_stride_order.index(i)
                for i in range(len(kv_cache_stride_order))
            ]

            raw_tensor = raw_tensor.view(dtype)
            raw_tensor = raw_tensor.view(kv_cache_shape)
            kv_caches[layer_name] = raw_tensor.permute(*inv_order)
    return kv_caches
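The view/permute pair at the end is the subtle part: the flat buffer is first viewed in the backend's physical layout (the logical shape reordered by `get_kv_cache_stride_order()`), then permuted by the inverse order, so the returned tensor indexes in the logical order while its memory stays in the backend's preferred layout. A standalone numeric sketch (hypothetical shape and stride order, plain PyTorch):

```python
import math
import torch

logical_shape = (4, 16, 2, 8)  # hypothetical (num_blocks, block_size, kv_heads, head_size)
stride_order = (1, 0, 2, 3)    # hypothetical backend-preferred physical ordering

# Same two steps as _reshape_kv_cache: view in physical order, permute back.
physical_shape = tuple(logical_shape[i] for i in stride_order)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

raw = torch.zeros(math.prod(logical_shape), dtype=torch.float16)
kv_cache = raw.view(physical_shape).permute(*inv_order)

assert tuple(kv_cache.shape) == logical_shape            # callers index logically
assert kv_cache.permute(*stride_order).is_contiguous()   # memory stays in physical order
```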

def init_kv_cache(
    runner_kv_caches: list[torch.Tensor],
    forward_context: dict[str, Any],
    kv_cache_config: KVCacheConfig,
    attn_backends: dict[str, AttentionBackend],
    device: torch.device,
):
    kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
    kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
                                  attn_backends)
    bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
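Taken together, a worker would presumably call these helpers in this order: `get_kv_cache_spec` feeds KV cache planning, which produces the `KVCacheConfig` consumed by `init_attn_backend` and `init_kv_cache`. A hedged sketch of that wiring (not the PR's actual worker code; `setup_kv_cache` is a hypothetical wrapper and all arguments are assumed to be supplied by the worker/runner):

```python
def setup_kv_cache(runner_kv_caches, forward_context, kv_cache_config,
                   vllm_config, device):
    # Pick one attention backend + metadata builder per KV cache group, then
    # allocate, reshape, and bind the caches into the forward context.
    attn_backends, attn_metadata_builders = init_attn_backend(
        kv_cache_config, vllm_config, device)
    init_kv_cache(runner_kv_caches, forward_context, kv_cache_config,
                  attn_backends, device)
    return attn_backends, attn_metadata_builders
```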