115 changes: 39 additions & 76 deletions vllm/attention/layer.py
@@ -15,14 +15,10 @@
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -842,41 +838,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
)


def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return

connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: list[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return

connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])


def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
@@ -911,23 +872,46 @@ def maybe_calc_kv_scales_fake(
)


def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
def get_attention_context(
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
"""Extract attention context for a given layer.

This helper function extracts the attention metadata, attention layer
instance, and KV cache tensor for a specific layer.

Args:
layer_name: The name/identifier of the attention layer.

Returns:
A tuple containing:
- attn_metadata: Attention metadata for this specific layer, or None if
no metadata is available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for the current virtual engine

Note: attn_metadata may be None, but attn_layer and kv_cache are always
extracted from the forward context.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output


@@ -947,6 +931,7 @@ def unified_attention_fake(
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
@@ -956,13 +941,7 @@ def unified_attention_with_output(
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
query,
@@ -975,8 +954,6 @@ def unified_attention_with_output(
output_block_scale=output_block_scale,
)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
query: torch.Tensor,
@@ -998,23 +975,16 @@ def unified_attention_with_output_fake(
)


@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output


@@ -1036,6 +1006,7 @@ def unified_mla_attention_fake(
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
@@ -1045,13 +1016,7 @@ def unified_mla_attention_with_output(
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self: MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
q,
@@ -1064,8 +1029,6 @@ def unified_mla_attention_with_output(
output_block_scale=output_block_scale,
)

maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_mla_attention_with_output_fake(
q: torch.Tensor,
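The new get_attention_context helper centralizes the per-layer lookups that each unified attention op previously repeated inline. Below is a minimal, self-contained sketch of that lookup, using types.SimpleNamespace stand-ins for the forward context and the attention layer; the names fake_ctx and fake_layer, the layer name string, and the tensor shape are illustrative assumptions, not vLLM APIs.

import torch
from types import SimpleNamespace

# One KV cache entry per virtual engine; a single engine here.
fake_layer = SimpleNamespace(kv_cache=[torch.zeros(2, 4)])
fake_ctx = SimpleNamespace(
    attn_metadata={"model.layers.0.self_attn.attn": "layer0-metadata"},
    no_compile_layers={"model.layers.0.self_attn.attn": fake_layer},
    virtual_engine=0,
)

layer_name = "model.layers.0.self_attn.attn"
attn_metadata = fake_ctx.attn_metadata
if isinstance(attn_metadata, dict):
    # Per-layer metadata is keyed by the layer's fully qualified name.
    attn_metadata = attn_metadata[layer_name]
attn_layer = fake_ctx.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[fake_ctx.virtual_engine]
print(attn_metadata, tuple(kv_cache.shape))  # layer0-metadata (2, 4)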
60 changes: 60 additions & 0 deletions vllm/attention/utils/kv_transfer_utils.py
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps

from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)


def maybe_transfer_kv_layer(func: Callable) -> Callable:
"""Decorator that handles KV layer transfer prior and after execution of
an attention layer, if enabled. Otherwise, the wrapper is a no-op.

On entry: waits for the KV layer from the connector.
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from vllm.attention.layer import get_attention_context

# Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())

# Find the index of 'layer_name' parameter.
try:
layer_name_index = param_names.index("layer_name")
except ValueError as e:
raise TypeError(
f"Function {func.__name__} must have a 'layer_name' parameter"
) from e

@wraps(func)
def wrapper(*args, **kwargs):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)

layer_name: str = args[layer_name_index]

# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)

# Wait for KV layer on entry
connector.wait_for_layer_load(layer_name)

# Execute the function
result = func(*args, **kwargs)

# Save KV cache layer on exit
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

return result

return wrapper
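Below is a minimal, self-contained sketch of the wrapping pattern maybe_transfer_kv_layer uses, with a stub connector standing in for the KV transfer group; the names _StubConnector, transfer_kv_layer, and fake_attention are illustrative, not vLLM APIs. It shows the wait -> run -> save ordering and why the positional index of layer_name is resolved once at decoration time rather than on every call.

import inspect
from collections.abc import Callable
from functools import wraps


class _StubConnector:
    def wait_for_layer_load(self, layer_name: str) -> None:
        print(f"wait  {layer_name}")

    def save_kv_layer(self, layer_name: str, kv_cache, attn_metadata) -> None:
        print(f"save  {layer_name}")


_connector = _StubConnector()


def transfer_kv_layer(func: Callable) -> Callable:
    # Resolve the positional index of 'layer_name' once, when the decorator is
    # applied, so the per-call wrapper does no signature inspection.
    params = list(inspect.signature(func).parameters)
    layer_name_index = params.index("layer_name")

    @wraps(func)
    def wrapper(*args, **kwargs):
        layer_name = args[layer_name_index]
        _connector.wait_for_layer_load(layer_name)        # on entry
        result = func(*args, **kwargs)                    # run the attention op
        _connector.save_kv_layer(layer_name, None, None)  # on exit
        return result

    return wrapper


@transfer_kv_layer
def fake_attention(query, key, value, layer_name):
    print(f"attn  {layer_name}")
    return query


fake_attention(1, 2, 3, "model.layers.0.self_attn.attn")  # prints wait, attn, save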