
Commit 5d84b99

Deprecate args in Attention.forward instead
Signed-off-by: Harry Mellor <[email protected]>
1 parent 7e0c808 commit 5d84b99

File tree

1 file changed (+18, -8)

vllm/attention/layer.py

Lines changed: 18 additions & 8 deletions
@@ -7,15 +7,15 @@
 import torch.nn.functional as F

 import vllm.envs as envs
-from vllm.attention import AttentionType
+from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import deprecate_args, direct_register_custom_op


 class Attention(nn.Module):
@@ -148,15 +148,25 @@ def __init__(
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

+    @deprecate_args(
+        4,
+        additional_message=
+        "In Attention, kv_cache is accessed via self.kv_cache and "
+        "attn_metadata is accessed via forward context.")
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
+        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
+        # directly, use `self.kv_cache` and
+        # `get_forward_context().attn_metadata` instead.
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
+            ctx_attn_metadata = get_forward_context().attn_metadata
+            if ctx_attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
@@ -172,14 +182,14 @@ def forward(
             value = value.view(-1, self.num_kv_heads, self.head_size)
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
-            attn_metadata = forward_context.attn_metadata
+            ctx_attn_metadata = forward_context.attn_metadata
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(self,
                               query,
                               key,
                               value,
                               self_kv_cache,
-                              attn_metadata,
+                              ctx_attn_metadata,
                               output=output)
         else:
             torch.ops.vllm.unified_attention_with_output(
@@ -188,10 +198,10 @@ def forward(
         else:
             if self.use_direct_call:
                 forward_context = get_forward_context()
-                attn_metadata = forward_context.attn_metadata
+                ctx_attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 return self.impl.forward(self, query, key, value,
-                                         self_kv_cache, attn_metadata)
+                                         self_kv_cache, ctx_attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
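For callers, the practical effect is that kv_cache and attn_metadata no longer need to be passed to Attention.forward; the layer reads self.kv_cache and get_forward_context().attn_metadata on its own. A before/after sketch is shown below; the names attn, q, k, v, kv_cache, and attn_metadata are placeholders, not code from this commit.

# Illustrative caller-side migration; placeholder variable names.

# Before: the cache and metadata were threaded through the call.
out = attn(q, k, v, kv_cache, attn_metadata)

# After: the layer reads self.kv_cache and
# get_forward_context().attn_metadata itself, so callers pass only q/k/v.
# The old arguments are still accepted but trigger a deprecation warning.
out = attn(q, k, v)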

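The @deprecate_args(4, ...) decorator appears to mark parameters from positional index 4 onward (kv_cache and attn_metadata, counting self as index 0) as deprecated. Below is a minimal sketch of how such a decorator could behave, written as an assumption for illustration; it is not vllm's actual vllm.utils.deprecate_args, and it only inspects positional arguments.

# Hypothetical sketch only -- NOT vllm.utils.deprecate_args. Assumes the
# first argument is the index of the first deprecated positional parameter
# (counting `self`); the real helper may also handle keyword arguments.
import functools
import warnings
from typing import Any, Callable


def deprecate_args_sketch(start_index: int,
                          additional_message: str = "") -> Callable:
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Warn if any positional argument at or beyond start_index
            # was supplied by the caller.
            if len(args) > start_index:
                warnings.warn(
                    f"Positional arguments {start_index} and onward of "
                    f"{fn.__qualname__} are deprecated. {additional_message}",
                    DeprecationWarning,
                    stacklevel=2)
            return fn(*args, **kwargs)

        return wrapper

    return decorator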