diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 91015ad4379c..4faa7e2af617 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -138,7 +138,11 @@
     UBatchSlices,
     check_ubatch_thresholds,
 )
-from vllm.v1.worker.utils import is_residual_scattered_for_sp
+from vllm.v1.worker.utils import (
+    get_runner_only_attn_layers,
+    init_kv_sharing,
+    is_residual_scattered_for_sp,
+)
 
 from .utils import (
     AttentionGroup,
@@ -146,6 +150,7 @@
     add_kv_sharing_layers_to_kv_cache_groups,
     bind_kv_cache,
     gather_mm_placeholders,
+    get_attn_backend_classes,
     sanity_check_mm_encoder_outputs,
     scatter_mm_placeholders,
 )
@@ -3070,6 +3075,34 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
             )
 
+        # Attention layers that are only in the KVCacheConfig of the runner
+        # (e.g., KV sharing, encoder-only attention), but not in the
+        # KVCacheConfig of the scheduler.
+        assert len(self.runner_only_attn_layers) == 0, (
+            "runner_only_attn_layers is not empty"
+        )
+        self.runner_only_attn_layers.update(
+            get_runner_only_attn_layers(self.vllm_config)
+        )
+        (
+            self.shared_kv_cache_layers,
+            self.kv_sharing_fast_prefill_eligible_layers,
+            self.kv_sharing_fast_prefill_logits_indices,
+        ) = init_kv_sharing(self.vllm_config, self.device)
+
+        # Resolve cudagraph_mode
+        self._check_and_update_cudagraph_mode()
+
+        if self.dcp_world_size > 1:
+            layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+            for layer in layers.values():
+                assert layer.impl.need_to_return_lse_for_decode, (
+                    "DCP requires attention impls to return"
+                    " the softmax lse for decode, but the impl "
+                    f"{layer.impl.__class__.__name__} "
+                    "does not return the softmax lse for decode."
+                )
+
     def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
         """Extract Eagle3 auxiliary layer indices from speculative config.
 
@@ -4088,9 +4121,6 @@ def create_attn_groups(
             attention_backend_maps.append(attn_backends[0])
             attention_backend_set.update(attn_backends[1])
 
-        # Resolve cudagraph_mode before actually initialize metadata_builders
-        self._check_and_update_cudagraph_mode(attention_backend_set)
-
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
@@ -4117,15 +4147,18 @@ def initialize_metadata_builders(
         # because some of them change the threshold at init time.
        self.calculate_reorder_batch_threshold()
 
-    def _check_and_update_cudagraph_mode(
-        self, attention_backends: set[type[AttentionBackend]]
-    ) -> None:
+    def _check_and_update_cudagraph_mode(self) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
         backends with potential conflicting CUDA graph support.
         Then initialize the cudagraph_dispatcher based on the resolved
         cudagraph_mode.
         """
+        attention_backends = set(
+            get_attn_backend_classes(
+                self.vllm_config, self.kv_sharing_fast_prefill_eligible_layers
+            ).values()
+        )
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_backend_name = None
 
@@ -4648,20 +4681,8 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
         add_kv_sharing_layers_to_kv_cache_groups(
             self.shared_kv_cache_layers,
             kv_cache_config.kv_cache_groups,
-            self.runner_only_attn_layers,
         )
 
-        if self.cache_config.kv_sharing_fast_prefill:
-            # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
-            # similar KV sharing setups, only the layers that generate KV caches
-            # are involved in the prefill phase, enabling prefill to early exit.
-            attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
-            for layer_name in reversed(attn_layers):
-                if layer_name in self.shared_kv_cache_layers:
-                    self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
-                else:
-                    break
-
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -4701,19 +4722,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
 
-        if self.dcp_world_size > 1:
-            layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
-            for layer in layers.values():
-                assert layer.impl.need_to_return_lse_for_decode, (
-                    "DCP requires attention impls to return"
-                    " the softmax lse for decode, but the impl "
-                    f"{layer.impl.__class__.__name__} "
-                    "does not return the softmax lse for decode."
-                )
-
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
-        Add encoder-only layers to the KV cache config.
+        Add encoder-only attention layers to the KV cache config.
         """
         block_size = self.vllm_config.cache_config.block_size
         encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
@@ -4727,7 +4738,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
                 dtype=self.kv_cache_dtype,
             )
             encoder_only_attn_specs[attn_spec].append(layer_name)
-            self.runner_only_attn_layers.add(layer_name)
         if len(encoder_only_attn_specs) > 0:
             assert len(encoder_only_attn_specs) == 1, (
                 "Only support one encoder-only attention spec now"
@@ -4750,7 +4760,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention) and (
-                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+                attn_module.kv_sharing_target_layer_name
             ):
                 # The layer doesn't need its own KV cache and will use that of
                 # the target layer. We skip creating a KVCacheSpec for it, so
@@ -4759,7 +4769,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 # enables the memory saving of cross-layer kv sharing, allowing
                 # a given amount of memory to accommodate longer context lengths
                 # or enable more requests to be processed simultaneously.
-                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                 continue
             # Skip modules that don't need KV cache (eg encoder-only attention)
             if spec := attn_module.get_kv_cache_spec(self.vllm_config):
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 0ca7e81a5c7b..0ccf143ff7c1 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -7,13 +7,23 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
+from vllm.attention.layer import Attention, AttentionType
+from vllm.config import (
+    ModelConfig,
+    SchedulerConfig,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.cache import processor_only_cache_from_config
 from vllm.multimodal.registry import MultiModalRegistry
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder,
+    create_fast_prefill_custom_backend,
+)
 from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
 
@@ -248,7 +258,6 @@ def gather_mm_placeholders(
 def add_kv_sharing_layers_to_kv_cache_groups(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
-    runner_only_attn_layers: set[str] | None = None,
 ) -> None:
     """
     Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -272,9 +281,6 @@
         tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
         tgt_kv_cache_group.layer_names.append(layer_name)
 
-        if runner_only_attn_layers is not None:
-            runner_only_attn_layers.add(layer_name)
-
 
 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],
@@ -364,3 +370,101 @@ def is_residual_scattered_for_sp(
         return True
 
     return num_input_tokens in vllm_config.compilation_config.compile_sizes
+
+
+def get_runner_only_attn_layers(
+    vllm_config: VllmConfig,
+) -> set[str]:
+    """
+    Get the runner-only attention layers that the scheduler doesn't need to allocate
+    KV cache for.
+    """
+    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
+    runner_only_attn_layers: set[str] = set()
+    for layer_name, attn_module in attn_layers.items():
+        if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+            # Encoder-only attention layers don't need KV cache.
+            runner_only_attn_layers.add(layer_name)
+        elif attn_module.kv_sharing_target_layer_name:
+            # KV sharing layers use the KV cache of other layers.
+            runner_only_attn_layers.add(layer_name)
+    return runner_only_attn_layers
+
+
+def init_kv_sharing(
+    vllm_config: VllmConfig, device: torch.device
+) -> tuple[dict[str, str], set[str], torch.Tensor | None]:
+    """
+    Initialize the KV sharing layers.
+
+    Args:
+        vllm_config: The VllmConfig.
+        device: The device to initialize the KV sharing layers on.
+
+    Returns:
+        A tuple of:
+        shared_kv_cache_layers: A dictionary of layer names to their target layer
+            names.
+        kv_sharing_fast_prefill_eligible_layers: A set of layer names that are
+            eligible for fast prefill optimization mentioned in You Only Cache Once
+            (https://arxiv.org/abs/2405.05254).
+        kv_sharing_fast_prefill_logits_indices: A buffer for logits indices used in
+            fast prefill optimization. None if fast prefill optimization is not
+            enabled.
+    """
+
+    # If an Attention layer `layer_name` is in the keys of this dict, it
+    # means this layer will perform attention using the keys and values
+    # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+    shared_kv_cache_layers: dict[str, str] = {}
+    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer_name, attn_module in attn_layers.items():
+        if kv_tgt_layer := attn_module.kv_sharing_target_layer_name:
+            shared_kv_cache_layers[layer_name] = kv_tgt_layer
+
+    # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
+    # similar KV sharing setups, only the layers that generate KV caches
+    # are involved in the prefill phase, enabling prefill to early exit.
+    # Layers in this set will be skipped during prefill.
+    kv_sharing_fast_prefill_eligible_layers: set[str] = set()
+    kv_sharing_fast_prefill_logits_indices = None
+
+    if vllm_config.cache_config.kv_sharing_fast_prefill:
+        kv_sharing_fast_prefill_logits_indices = torch.zeros(
+            vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=device,
+        )
+        for layer_name in reversed(attn_layers):
+            if layer_name in shared_kv_cache_layers:
+                kv_sharing_fast_prefill_eligible_layers.add(layer_name)
+            else:
+                break
+
+    return (
+        shared_kv_cache_layers,
+        kv_sharing_fast_prefill_eligible_layers,
+        kv_sharing_fast_prefill_logits_indices,
+    )
+
+
+def get_attn_backend_classes(
+    vllm_config: VllmConfig,
+    kv_sharing_fast_prefill_eligible_layers: set[str],
+    layer_names: list[str] | None = None,
+) -> dict[str, type[AttentionBackend]]:
+    """
+    Get the attention backend classes for the given VllmConfig.
+    """
+    attn_layers = get_layers_from_vllm_config(
+        vllm_config, AttentionLayerBase, layer_names
+    )
+    attn_backend_cls_dict = {}
+    for layer_name, attn_module in attn_layers.items():
+        attn_backend = attn_module.get_attn_backend()
+        if layer_name in kv_sharing_fast_prefill_eligible_layers:
+            attn_backend = create_fast_prefill_custom_backend(
+                "FastPrefill",
+                attn_backend,
+            )
+        attn_backend_cls_dict[layer_name] = attn_backend
+    return attn_backend_cls_dict
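
Reviewer note (not part of the patch): the fast-prefill eligibility scan that this change moves from maybe_add_kv_sharing_layers_to_kv_cache_groups() into init_kv_sharing() walks attn_layers in reverse insertion order and stops at the first layer that owns its own KV cache, so only a contiguous trailing run of KV-sharing layers is ever marked eligible; an interior KV-sharing layer sandwiched between cache-owning layers is never added. A minimal, self-contained sketch of just that selection rule follows (plain dicts and invented layer names, not the vLLM API):

    def trailing_fast_prefill_layers(
        attn_layers: dict[str, object], shared_kv_cache_layers: dict[str, str]
    ) -> set[str]:
        """Select the trailing block of layers that reuse another layer's KV cache."""
        eligible: set[str] = set()
        # Scan from the last layer backwards; stop at the first layer that
        # produces its own KV cache, mirroring the loop in init_kv_sharing().
        for layer_name in reversed(attn_layers):
            if layer_name in shared_kv_cache_layers:
                eligible.add(layer_name)
            else:
                break
        return eligible

    # Example: layers 2 and 3 reuse the KV caches of layers 0 and 1 (YOCO-style).
    layers = {f"model.layers.{i}.self_attn": None for i in range(4)}
    sharing = {
        "model.layers.2.self_attn": "model.layers.0.self_attn",
        "model.layers.3.self_attn": "model.layers.1.self_attn",
    }
    print(trailing_fast_prefill_layers(layers, sharing))
    # Prints both KV-sharing layers. If only layer 2 shared its KV cache, the
    # reverse scan would hit layer 3 first, break, and return an empty set.

The patch keeps this behavior identical; it only centralizes it in vllm/v1/worker/utils.py so the same result can be reused when resolving per-layer attention backends.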