75 changes: 42 additions & 33 deletions vllm/v1/worker/gpu_model_runner.py
@@ -138,14 +138,19 @@
UBatchSlices,
check_ubatch_thresholds,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.utils import (
Member

Just a thought ... what belongs in utils, and what belongs in the model runner? Are you putting these in utils so they can be re-used by other model runners? Is that a good refactoring goal in general - move as much code as possible into utils?

Member

Oh, I see

This PR also starts to change some functions in gpu model runner into pure functions so that they can be reused by model runner v2 in the future.

I guess my preference would be to keep the purpose of the PR clean - move some steps into load_model() in this PR and do the more complete "prepare for model runner v2" refactoring in a separate PR. It's hard to judge whether these functions are a positive refactoring move in the context of this PR.

Not a strong objection though 🤷

get_runner_only_attn_layers,
init_kv_sharing,
is_residual_scattered_for_sp,
)

from .utils import (
AttentionGroup,
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
gather_mm_placeholders,
get_attn_backend_classes,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
@@ -3070,6 +3075,34 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
)

# Attention layers that are only in the KVCacheConfig of the runner
# (e.g., KV sharing, encoder-only attention), but not in the
# KVCacheConfig of the scheduler.
assert len(self.runner_only_attn_layers) == 0, (
"runner_only_attn_layers is not empty"
)
self.runner_only_attn_layers.update(
get_runner_only_attn_layers(self.vllm_config)
)
Member

Can we get rid of the initialization in the constructor and the assertion here, and just do

self.runner_only_attn_layers = get_runner_only_attn_layers(self.vllm_config)

(
self.shared_kv_cache_layers,
self.kv_sharing_fast_prefill_eligible_layers,
self.kv_sharing_fast_prefill_logits_indices,
Member

This is already initialized in the constructor; did you forget to remove it from there?

But leaving it in the constructor seems fine? That would also mean the device arg can be removed from utils.kv_sharing(), making it only return layers, which is a nice simplification.
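
A rough sketch of what that simplified helper might look like (hypothetical, not part of this diff; it assumes the logits-indices buffer stays in the runner's constructor so the device arg is no longer needed):

# Hypothetical sketch of the suggested simplification; not part of this PR.
def init_kv_sharing(vllm_config: VllmConfig) -> tuple[dict[str, str], set[str]]:
    # Map each KV-sharing layer to the layer whose KV cache it reuses.
    shared_kv_cache_layers: dict[str, str] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer_name, attn_module in attn_layers.items():
        if kv_tgt_layer := attn_module.kv_sharing_target_layer_name:
            shared_kv_cache_layers[layer_name] = kv_tgt_layer

    # Trailing KV-sharing layers are eligible for the YOCO-style fast-prefill
    # early exit (https://arxiv.org/abs/2405.05254).
    kv_sharing_fast_prefill_eligible_layers: set[str] = set()
    if vllm_config.cache_config.kv_sharing_fast_prefill:
        for layer_name in reversed(attn_layers):
            if layer_name in shared_kv_cache_layers:
                kv_sharing_fast_prefill_eligible_layers.add(layer_name)
            else:
                break

    return shared_kv_cache_layers, kv_sharing_fast_prefill_eligible_layers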

) = init_kv_sharing(self.vllm_config, self.device)

# Resolve cudagraph_mode
self._check_and_update_cudagraph_mode()

if self.dcp_world_size > 1:
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer in layers.values():
assert layer.impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer.impl.__class__.__name__} "
"does not return the softmax lse for decode."
)

def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
"""Extract Eagle3 auxiliary layer indices from speculative config.

@@ -4088,9 +4121,6 @@ def create_attn_groups(
attention_backend_maps.append(attn_backends[0])
attention_backend_set.update(attn_backends[1])

# Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)

for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))

@@ -4117,15 +4147,18 @@ def initialize_metadata_builders(
# because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold()

def _check_and_update_cudagraph_mode(
self, attention_backends: set[type[AttentionBackend]]
) -> None:
def _check_and_update_cudagraph_mode(self) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potential conflicting CUDA graph support.
Then initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
attention_backends = set(
get_attn_backend_classes(
self.vllm_config, self.kv_sharing_fast_prefill_eligible_layers
).values()
)
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_backend_name = None

@@ -4648,20 +4681,8 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
self.runner_only_attn_layers,
)

if self.cache_config.kv_sharing_fast_prefill:
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
else:
break

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -4701,19 +4722,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

if self.dcp_world_size > 1:
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer in layers.values():
assert layer.impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer.impl.__class__.__name__} "
"does not return the softmax lse for decode."
)

def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
Add encoder-only attention layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
@@ -4727,7 +4738,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
dtype=self.kv_cache_dtype,
)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(encoder_only_attn_specs) == 1, (
"Only support one encoder-only attention spec now"
@@ -4750,7 +4760,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention) and (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
attn_module.kv_sharing_target_layer_name
):
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
@@ -4759,7 +4769,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# Skip modules that don't need KV cache (eg encoder-only attention)
if spec := attn_module.get_kv_cache_spec(self.vllm_config):
116 changes: 110 additions & 6 deletions vllm/v1/worker/utils.py
@@ -7,13 +7,23 @@
import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.attention.layer import Attention, AttentionType
from vllm.config import (
ModelConfig,
SchedulerConfig,
VllmConfig,
get_layers_from_vllm_config,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
create_fast_prefill_custom_backend,
)
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec

@@ -248,7 +258,6 @@ def gather_mm_placeholders(
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
runner_only_attn_layers: set[str] | None = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
@@ -272,9 +281,6 @@
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)

if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)


def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
@@ -364,3 +370,101 @@ def is_residual_scattered_for_sp(
return True

return num_input_tokens in vllm_config.compilation_config.compile_sizes


def get_runner_only_attn_layers(
vllm_config: VllmConfig,
) -> set[str]:
"""
Get the runner-only attention layers that the scheduler doesn't need to allocate
KV cache for.
"""
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
runner_only_attn_layers: set[str] = set()
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
# Encoder-only attention layers don't need KV cache.
runner_only_attn_layers.add(layer_name)
elif attn_module.kv_sharing_target_layer_name:
# KV sharing layers use the KV cache of other layers.
runner_only_attn_layers.add(layer_name)
return runner_only_attn_layers


def init_kv_sharing(
vllm_config: VllmConfig, device: torch.device
) -> tuple[dict[str, str], set[str], torch.Tensor | None]:
"""
Initialize the KV sharing layers.

Args:
vllm_config: The VllmConfig.
device: The device to initialize the KV sharing layers on.

Returns:
A tuple of:
shared_kv_cache_layers: A dictionary of layer names to their target layer
names.
kv_sharing_fast_prefill_eligible_layers: A set of layer names that are
eligible for fast prefill optimization mentioned in You Only Cache Once
(https://arxiv.org/abs/2405.05254).
kv_sharing_fast_prefill_logits_indices: A buffer for logits indices used in
fast prefill optimization. None if fast prefill optimization is not enabled.
"""

# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
shared_kv_cache_layers: dict[str, str] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if kv_tgt_layer := attn_module.kv_sharing_target_layer_name:
shared_kv_cache_layers[layer_name] = kv_tgt_layer

# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
# Layers in this set will be skipped during prefill.
kv_sharing_fast_prefill_eligible_layers: set[str] = set()
kv_sharing_fast_prefill_logits_indices = None

if vllm_config.cache_config.kv_sharing_fast_prefill:
kv_sharing_fast_prefill_logits_indices = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.int32,
device=device,
)
for layer_name in reversed(attn_layers):
if layer_name in shared_kv_cache_layers:
kv_sharing_fast_prefill_eligible_layers.add(layer_name)
else:
break

return (
shared_kv_cache_layers,
kv_sharing_fast_prefill_eligible_layers,
kv_sharing_fast_prefill_logits_indices,
)


def get_attn_backend_classes(
vllm_config: VllmConfig,
kv_sharing_fast_prefill_eligible_layers: set[str],
layer_names: list[str] | None = None,
) -> dict[str, type[AttentionBackend]]:
"""
Get the attention backend classes for the given VllmConfig.
"""
attn_layers = get_layers_from_vllm_config(
vllm_config, AttentionLayerBase, layer_names
)
attn_backend_cls_dict = {}
for layer_name, attn_module in attn_layers.items():
attn_backend = attn_module.get_attn_backend()
if layer_name in kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
Member

Hmm, this gets created again later in initialize_attn_backend()?

Can this be avoided? E.g., can it be created elsewhere earlier so that get_layers_from_vllm_config() returns it?

Member

Sorry, this is probably a duplicate of @LucasWilkinson's comment.

"FastPrefill",
attn_backend,
)
attn_backend_cls_dict[layer_name] = attn_backend
return attn_backend_cls_dict