
Commit f3b6244

review

Signed-off-by: NickLucche <[email protected]>
1 parent e1f8ab2

2 files changed (+30, -34 lines)

vllm/attention/layer.py

Lines changed: 30 additions & 30 deletions
@@ -204,36 +204,6 @@ def _init_kv_cache_quant(
         layer.quant_method.create_weights(layer)
 
 
-def get_attention_context(
-    layer_name: str,
-) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]:
-    """Extract attention context for a given layer.
-
-    This helper function extracts the attention metadata, attention layer
-    instance, and KV cache tensor for a specific layer.
-
-    Args:
-        layer_name: The name/identifier of the attention layer.
-
-    Returns:
-        A tuple containing:
-        - attn_metadata: Attention metadata for this specific layer, or None if
-          no metadata available
-        - attn_layer: The attention layer instance (Attention or MLAAttention)
-        - kv_cache: The KV cache tensor for current virtual engine
-
-    Note: attn_metadata may be None, but attn_layer and kv_cache are always
-    extracted from the forward context.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
-    return attn_metadata, attn_layer, kv_cache
-
-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
@@ -907,6 +877,36 @@ def maybe_calc_kv_scales_fake(
 )
 
 
+def get_attention_context(
+    layer_name: str,
+) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
+    """Extract attention context for a given layer.
+
+    This helper function extracts the attention metadata, attention layer
+    instance, and KV cache tensor for a specific layer.
+
+    Args:
+        layer_name: The name/identifier of the attention layer.
+
+    Returns:
+        A tuple containing:
+        - attn_metadata: Attention metadata for this specific layer, or None if
+          no metadata available
+        - attn_layer: The attention layer instance (Attention or MLAAttention)
+        - kv_cache: The KV cache tensor for current virtual engine
+
+    Note: attn_metadata may be None, but attn_layer and kv_cache are always
+    extracted from the forward context.
+    """
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+    return attn_metadata, attn_layer, kv_cache
+
+
 @maybe_transfer_kv_layer
 def unified_attention(
     query: torch.Tensor,
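
This hunk relocates get_attention_context below the class definitions (which is also why the Attention | MLAAttention annotation no longer needs a string forward reference); the helper itself is unchanged. As a minimal usage sketch only (the function name inspect_layer_kv_cache and its return value are illustrative assumptions, not part of this commit; real callers such as unified_attention go on to run the layer's attention backend):

    # Minimal usage sketch, not code from this commit. Assumes it is called
    # inside a model forward pass so that a forward context is active.
    import torch

    from vllm.attention.layer import get_attention_context


    def inspect_layer_kv_cache(layer_name: str) -> tuple[torch.Size, bool]:
        """Return the KV-cache shape for `layer_name` and whether metadata is set."""
        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
        # attn_metadata may be None (e.g. during profiling); attn_layer and the
        # per-virtual-engine kv_cache tensor are always resolved.
        return kv_cache.shape, attn_metadata is not None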

vllm/attention/utils/kv_transfer_utils.py

Lines changed: 0 additions & 4 deletions
@@ -3,17 +3,13 @@
 import inspect
 from collections.abc import Callable
 from functools import wraps
-from typing import TYPE_CHECKING
 
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,
     is_v1_kv_transfer_group,
 )
 
-if TYPE_CHECKING:
-    pass
-
 
 def maybe_transfer_kv_layer(func: Callable) -> Callable:
     """Decorator that handles KV layer transfer prior and after execution of
