[GPUModelRunner] initialize_kv_cache cleanup (1/N): move initialization that doesn't depend on kv cache config to load_model #28258
base: main
Changes from all commits: d6d230f, cd404c3, 857510d, 28df894, c8a128d
First changed file:
```diff
@@ -138,14 +138,19 @@
     UBatchSlices,
     check_ubatch_thresholds,
 )
-from vllm.v1.worker.utils import is_residual_scattered_for_sp
+from vllm.v1.worker.utils import (
+    get_runner_only_attn_layers,
+    init_kv_sharing,
+    is_residual_scattered_for_sp,
+)

 from .utils import (
     AttentionGroup,
     MultiModalBudget,
     add_kv_sharing_layers_to_kv_cache_groups,
     bind_kv_cache,
     gather_mm_placeholders,
+    get_attn_backend_classes,
     sanity_check_mm_encoder_outputs,
     scatter_mm_placeholders,
 )
```
```diff
@@ -3070,6 +3075,34 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
         )

+        # Attention layers that are only in the KVCacheConfig of the runner
+        # (e.g., KV sharing, encoder-only attention), but not in the
+        # KVCacheConfig of the scheduler.
+        assert len(self.runner_only_attn_layers) == 0, (
+            "runner_only_attn_layers is not empty"
+        )
+        self.runner_only_attn_layers.update(
+            get_runner_only_attn_layers(self.vllm_config)
+        )
```
Member: Can we get rid of the initialization in the constructor and the assertion here, and just do …
```diff
+        (
+            self.shared_kv_cache_layers,
+            self.kv_sharing_fast_prefill_eligible_layers,
+            self.kv_sharing_fast_prefill_logits_indices,
+        ) = init_kv_sharing(self.vllm_config, self.device)
```

Member: This is already initialized in the constructor, you forgot to remove it from there? But leaving it in the constructor seems fine? That would also mean the …
```diff
+        # Resolve cudagraph_mode
+        self._check_and_update_cudagraph_mode()
+
+        if self.dcp_world_size > 1:
+            layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+            for layer in layers.values():
+                assert layer.impl.need_to_return_lse_for_decode, (
+                    "DCP requires attention impls to return"
+                    " the softmax lse for decode, but the impl "
+                    f"{layer.impl.__class__.__name__} "
+                    "does not return the softmax lse for decode."
+                )

     def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
         """Extract Eagle3 auxiliary layer indices from speculative config.
```
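Taken together, the hunk above amounts to the following call order at the end of `load_model()`. This is a condensed orientation sketch, not a literal excerpt of the final code; `runner` stands in for the `GPUModelRunner` instance, and the two helpers are the ones added to `vllm/v1/worker/utils.py` in this PR.

```python
from vllm.v1.worker.utils import get_runner_only_attn_layers, init_kv_sharing

def load_model_tail(runner) -> None:
    # 1) Record attention layers the scheduler never allocates KV cache for
    #    (encoder-only layers and layers that borrow another layer's KV cache).
    runner.runner_only_attn_layers.update(
        get_runner_only_attn_layers(runner.vllm_config)
    )
    # 2) Set up KV-sharing state; none of this depends on the KV cache config,
    #    so it no longer has to wait for initialize_kv_cache().
    (
        runner.shared_kv_cache_layers,
        runner.kv_sharing_fast_prefill_eligible_layers,
        runner.kv_sharing_fast_prefill_logits_indices,
    ) = init_kv_sharing(runner.vllm_config, runner.device)
    # 3) Resolve cudagraph_mode from config-derived attention backends
    #    instead of from the attention groups built later.
    runner._check_and_update_cudagraph_mode()
```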
```diff
@@ -4088,9 +4121,6 @@ def create_attn_groups(
             attention_backend_maps.append(attn_backends[0])
             attention_backend_set.update(attn_backends[1])

-        # Resolve cudagraph_mode before actually initialize metadata_builders
-        self._check_and_update_cudagraph_mode(attention_backend_set)
-
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
```
```diff
@@ -4117,15 +4147,18 @@ def initialize_metadata_builders(
         # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()

-    def _check_and_update_cudagraph_mode(
-        self, attention_backends: set[type[AttentionBackend]]
-    ) -> None:
+    def _check_and_update_cudagraph_mode(self) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
         backends with potential conflicting CUDA graph support.
         Then initialize the cudagraph_dispatcher based on the resolved
         cudagraph_mode.
         """
+        attention_backends = set(
+            get_attn_backend_classes(
+                self.vllm_config, self.kv_sharing_fast_prefill_eligible_layers
+            ).values()
+        )
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_backend_name = None
```
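The body that follows (`min_cg_support = AttentionCGSupport.ALWAYS`, `min_cg_backend_name = None`) reduces over the backend set to find the weakest CUDA graph support. A self-contained toy version of that reduction is below; the enum is a stand-in for `AttentionCGSupport`, and every member name other than `ALWAYS` is made up for illustration.

```python
from enum import IntEnum

class CGSupport(IntEnum):
    # Stand-in for AttentionCGSupport; members other than ALWAYS are hypothetical.
    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2

def weakest_support(backend_support: dict[str, CGSupport]) -> tuple[CGSupport, str | None]:
    # Start optimistic and lower the mode whenever a backend supports less.
    min_cg_support = CGSupport.ALWAYS
    min_cg_backend_name = None
    for name, support in backend_support.items():
        if support < min_cg_support:
            min_cg_support, min_cg_backend_name = support, name
    return min_cg_support, min_cg_backend_name

print(weakest_support({"BackendA": CGSupport.ALWAYS, "BackendB": CGSupport.UNIFORM_BATCH}))
# (<CGSupport.UNIFORM_BATCH: 1>, 'BackendB')
```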
```diff
@@ -4648,20 +4681,8 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
         add_kv_sharing_layers_to_kv_cache_groups(
             self.shared_kv_cache_layers,
             kv_cache_config.kv_cache_groups,
-            self.runner_only_attn_layers,
         )

-        if self.cache_config.kv_sharing_fast_prefill:
-            # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
-            # similar KV sharing setups, only the layers that generate KV caches
-            # are involved in the prefill phase, enabling prefill to early exit.
-            attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
-            for layer_name in reversed(attn_layers):
-                if layer_name in self.shared_kv_cache_layers:
-                    self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
-                else:
-                    break
-
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
```
```diff
@@ -4701,19 +4722,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

-        if self.dcp_world_size > 1:
-            layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
-            for layer in layers.values():
-                assert layer.impl.need_to_return_lse_for_decode, (
-                    "DCP requires attention impls to return"
-                    " the softmax lse for decode, but the impl "
-                    f"{layer.impl.__class__.__name__} "
-                    "does not return the softmax lse for decode."
-                )
-
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
-        Add encoder-only layers to the KV cache config.
+        Add encoder-only attention layers to the KV cache config.
         """
         block_size = self.vllm_config.cache_config.block_size
         encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
```
```diff
@@ -4727,7 +4738,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
                 dtype=self.kv_cache_dtype,
             )
             encoder_only_attn_specs[attn_spec].append(layer_name)
-            self.runner_only_attn_layers.add(layer_name)
         if len(encoder_only_attn_specs) > 0:
             assert len(encoder_only_attn_specs) == 1, (
                 "Only support one encoder-only attention spec now"
```
```diff
@@ -4750,7 +4760,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention) and (
-                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+                attn_module.kv_sharing_target_layer_name
             ):
                 # The layer doesn't need its own KV cache and will use that of
                 # the target layer. We skip creating a KVCacheSpec for it, so
```
```diff
@@ -4759,7 +4769,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 # enables the memory saving of cross-layer kv sharing, allowing
                 # a given amount of memory to accommodate longer context lengths
                 # or enable more requests to be processed simultaneously.
-                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                 continue
             # Skip modules that don't need KV cache (eg encoder-only attention)
             if spec := attn_module.get_kv_cache_spec(self.vllm_config):
```
Second changed file:
```diff
@@ -7,13 +7,23 @@
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
+from vllm.attention.layer import Attention, AttentionType
+from vllm.config import (
+    ModelConfig,
+    SchedulerConfig,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.cache import processor_only_cache_from_config
 from vllm.multimodal.registry import MultiModalRegistry
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder,
+    create_fast_prefill_custom_backend,
+)
 from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
```
```diff
@@ -248,7 +258,6 @@ def gather_mm_placeholders(
 def add_kv_sharing_layers_to_kv_cache_groups(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
-    runner_only_attn_layers: set[str] | None = None,
 ) -> None:
     """
     Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
```
```diff
@@ -272,9 +281,6 @@ def add_kv_sharing_layers_to_kv_cache_groups(
         tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
         tgt_kv_cache_group.layer_names.append(layer_name)

-        if runner_only_attn_layers is not None:
-            runner_only_attn_layers.add(layer_name)
-

 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],
```
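For context on what the now-simplified helper does, here is a standalone toy of the grouping step shown above; plain dicts and lists stand in for `KVCacheGroupSpec`, and the layer names are made up.

```python
# Each KV cache group lists the layers whose caches are allocated together.
kv_cache_groups = [
    {"layer_names": ["layers.0.attn"]},
    {"layer_names": ["layers.1.attn"]},
]
# Layers 2 and 3 reuse the caches of layers 0 and 1 respectively.
shared_kv_cache_layers = {"layers.2.attn": "layers.0.attn", "layers.3.attn": "layers.1.attn"}

# Map every existing layer to its group, then append each sharing layer to the
# group of its target layer (mirrors add_kv_sharing_layers_to_kv_cache_groups).
layer_to_group = {
    name: group for group in kv_cache_groups for name in group["layer_names"]
}
for layer_name, target_layer_name in shared_kv_cache_layers.items():
    layer_to_group[target_layer_name]["layer_names"].append(layer_name)

print([g["layer_names"] for g in kv_cache_groups])
# [['layers.0.attn', 'layers.2.attn'], ['layers.1.attn', 'layers.3.attn']]
```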
```diff
@@ -364,3 +370,101 @@ def is_residual_scattered_for_sp(
         return True

     return num_input_tokens in vllm_config.compilation_config.compile_sizes
+
+
+def get_runner_only_attn_layers(
+    vllm_config: VllmConfig,
+) -> set[str]:
+    """
+    Get the runner-only attention layers that the scheduler doesn't need to allocate
+    KV cache for.
+    """
+    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
+    runner_only_attn_layers: set[str] = set()
+    for layer_name, attn_module in attn_layers.items():
+        if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+            # Encoder-only attention layers don't need KV cache.
+            runner_only_attn_layers.add(layer_name)
+        elif attn_module.kv_sharing_target_layer_name:
+            # KV sharing layers use the KV cache of other layers.
+            runner_only_attn_layers.add(layer_name)
+    return runner_only_attn_layers
```
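A self-contained toy example of the classification this helper performs; the layer class and names below are made-up stand-ins, not vLLM types.

```python
from dataclasses import dataclass

@dataclass
class FakeAttnLayer:
    attn_type: str                                   # "decoder" or "encoder_only"
    kv_sharing_target_layer_name: str | None = None  # set if this layer borrows KV

layers = {
    "layers.0.attn": FakeAttnLayer("decoder"),
    "layers.1.attn": FakeAttnLayer("decoder", kv_sharing_target_layer_name="layers.0.attn"),
    "encoder.0.attn": FakeAttnLayer("encoder_only"),
}

# Same rule as get_runner_only_attn_layers: encoder-only layers and layers that
# borrow another layer's KV cache never need scheduler-side KV allocation.
runner_only = {
    name
    for name, layer in layers.items()
    if layer.attn_type == "encoder_only" or layer.kv_sharing_target_layer_name
}
print(sorted(runner_only))  # ['encoder.0.attn', 'layers.1.attn']
```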
```diff
+def init_kv_sharing(
+    vllm_config: VllmConfig, device: torch.device
+) -> tuple[dict[str, str], set[str], torch.Tensor | None]:
+    """
+    Initialize the KV sharing layers.
+
+    Args:
+        vllm_config: The VllmConfig.
+        device: The device to initialize the KV sharing layers on.
+
+    Returns:
+        A tuple of:
+        shared_kv_cache_layers: A dictionary of layer names to their target layer
+            names.
+        kv_sharing_fast_prefill_eligible_layers: A set of layer names that are
+            eligible for fast prefill optimization mentioned in You Only Cache Once
+            (https://arxiv.org/abs/2405.05254).
+        kv_sharing_fast_prefill_logits_indices: A buffer for logits indices used in
+            fast prefill optimization. None if fast prefill optimization is not
+            enabled.
+    """
+    # If an Attention layer `layer_name` is in the keys of this dict, it
+    # means this layer will perform attention using the keys and values
+    # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+    shared_kv_cache_layers: dict[str, str] = {}
+    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer_name, attn_module in attn_layers.items():
+        if kv_tgt_layer := attn_module.kv_sharing_target_layer_name:
+            shared_kv_cache_layers[layer_name] = kv_tgt_layer
+
+    # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
+    # similar KV sharing setups, only the layers that generate KV caches
+    # are involved in the prefill phase, enabling prefill to early exit.
+    # Layers in this set will be skipped during prefill.
+    kv_sharing_fast_prefill_eligible_layers: set[str] = set()
+    kv_sharing_fast_prefill_logits_indices = None
+
+    if vllm_config.cache_config.kv_sharing_fast_prefill:
+        kv_sharing_fast_prefill_logits_indices = torch.zeros(
+            vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=device,
+        )
+        for layer_name in reversed(attn_layers):
+            if layer_name in shared_kv_cache_layers:
+                kv_sharing_fast_prefill_eligible_layers.add(layer_name)
+            else:
+                break
+
+    return (
+        shared_kv_cache_layers,
+        kv_sharing_fast_prefill_eligible_layers,
+        kv_sharing_fast_prefill_logits_indices,
+    )
```
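The fast-prefill scan above walks the layers in reverse and stops at the first layer that still writes its own KV cache. A small standalone illustration with made-up layer names:

```python
# Layers 2 and 3 reuse earlier layers' KV caches (as in YOCO-style sharing).
shared_kv_cache_layers = {
    "layers.2.attn": "layers.0.attn",
    "layers.3.attn": "layers.1.attn",
}
attn_layers = ["layers.0.attn", "layers.1.attn", "layers.2.attn", "layers.3.attn"]

eligible: set[str] = set()
for layer_name in reversed(attn_layers):
    if layer_name in shared_kv_cache_layers:
        eligible.add(layer_name)
    else:
        # An earlier layer still produces KV, so prefill cannot exit before it.
        break

print(sorted(eligible))  # ['layers.2.attn', 'layers.3.attn']
```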
```diff
+def get_attn_backend_classes(
+    vllm_config: VllmConfig,
+    kv_sharing_fast_prefill_eligible_layers: set[str],
+    layer_names: list[str] | None = None,
+) -> dict[str, type[AttentionBackend]]:
+    """
+    Get the attention backend classes for the given VllmConfig.
+    """
+    attn_layers = get_layers_from_vllm_config(
+        vllm_config, AttentionLayerBase, layer_names
+    )
+    attn_backend_cls_dict = {}
+    for layer_name, attn_module in attn_layers.items():
+        attn_backend = attn_module.get_attn_backend()
+        if layer_name in kv_sharing_fast_prefill_eligible_layers:
+            attn_backend = create_fast_prefill_custom_backend(
+                "FastPrefill",
+                attn_backend,
+            )
+        attn_backend_cls_dict[layer_name] = attn_backend
+    return attn_backend_cls_dict
```

Member: Hmm, this gets created again later in … Can this be avoided? E.g. can it be created elsewhere earlier so …

Member: Sorry, this is probably a duplicate of @LucasWilkinson's comment.
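To make the per-layer wrapping decision above concrete, here is a toy version of the selection; the string backends and the `wrap_fast_prefill` helper are stand-ins for the real backend classes and `create_fast_prefill_custom_backend`.

```python
def wrap_fast_prefill(backend: str) -> str:
    # Stand-in for create_fast_prefill_custom_backend("FastPrefill", backend).
    return f"FastPrefill({backend})"

def pick_backends(
    layer_backends: dict[str, str], fast_prefill_layers: set[str]
) -> dict[str, str]:
    # Wrap only the layers eligible for the fast-prefill optimization.
    resolved: dict[str, str] = {}
    for layer_name, backend in layer_backends.items():
        if layer_name in fast_prefill_layers:
            backend = wrap_fast_prefill(backend)
        resolved[layer_name] = backend
    return resolved

print(pick_backends(
    {"layers.0.attn": "FlashAttnBackend", "layers.3.attn": "FlashAttnBackend"},
    {"layers.3.attn"},
))
# {'layers.0.attn': 'FlashAttnBackend', 'layers.3.attn': 'FastPrefill(FlashAttnBackend)'}
```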
Just a thought ... what belongs in utils, and what belongs in the model runner? Are you putting these in utils so they can be re-used by other model runners? Is that a good refactoring goal in general - move as much code as possible into utils?

Oh, I see. I guess my preference would be to keep the purpose of the PR clean - move some steps into `load_model()` in this PR and do some more complete "prepare for model runner v2" refactoring in a separate PR. It's hard to judge whether these functions are a positive refactoring move in the context of this PR. Not a strong objection though 🤷