 import torch.nn.functional as F

 import vllm.envs as envs
-from vllm.attention import AttentionType
+from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import deprecate_args, direct_register_custom_op


 class Attention(nn.Module):
@@ -148,15 +148,25 @@ def __init__(
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

+    @deprecate_args(
+        4,
+        additional_message=
+        "In Attention, kv_cache is accessed via self.kv_cache and "
+        "attn_metadata is accessed via forward context.")
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
+        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
+        # directly, use `self.kv_cache` and
+        # `get_forward_context().attn_metadata` instead.
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
+            ctx_attn_metadata = get_forward_context().attn_metadata
+            if ctx_attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
@@ -172,14 +182,14 @@ def forward(
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
                 forward_context: ForwardContext = get_forward_context()
-                attn_metadata = forward_context.attn_metadata
+                ctx_attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
                                   query,
                                   key,
                                   value,
                                   self_kv_cache,
-                                  attn_metadata,
+                                  ctx_attn_metadata,
                                   output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
@@ -188,10 +198,10 @@ def forward(
         else:
             if self.use_direct_call:
                 forward_context = get_forward_context()
-                attn_metadata = forward_context.attn_metadata
+                ctx_attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 return self.impl.forward(self, query, key, value,
-                                         self_kv_cache, attn_metadata)
+                                         self_kv_cache, ctx_attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
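
For reference, a minimal sketch of how a caller adapts to this change: instead of threading `kv_cache` and `attn_metadata` through `Attention.forward`, the layer now pulls the cache from `self.kv_cache` and the metadata from the forward context. The `SimpleDecoderLayer` module below is hypothetical and only illustrates the calling convention, not any real vLLM model code.

import torch
import torch.nn as nn

from vllm.attention import Attention


class SimpleDecoderLayer(nn.Module):
    """Hypothetical caller illustrating the post-deprecation call style."""

    def __init__(self, attn: Attention):
        super().__init__()
        self.attn = attn

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        # Deprecated style (still accepted, but now warns via @deprecate_args):
        #   self.attn(q, k, v, kv_cache, attn_metadata)
        # New style: Attention reads self.kv_cache and
        # get_forward_context().attn_metadata internally.
        return self.attn(q, k, v)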
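
The diff also leans on `deprecate_args` from `vllm.utils`. Judging only from the call site, it warns when callers still supply positional arguments at or beyond index 4 (i.e. `kv_cache` and `attn_metadata`). The helper below is a hypothetical stand-in with a made-up name and signature that illustrates that behavior; it is not vLLM's implementation.

import functools
import warnings
from typing import Any, Callable


def deprecate_positional_args_from(start_index: int,
                                   additional_message: str = "") -> Callable:
    """Warn when positional args at or beyond `start_index` are passed."""

    def decorator(fn: Callable) -> Callable:

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if len(args) > start_index:
                # Callers passing the deprecated positional arguments get a
                # DeprecationWarning but the call still goes through.
                warnings.warn(
                    f"Positional arguments {start_index}+ of {fn.__name__} "
                    f"are deprecated. {additional_message}",
                    DeprecationWarning,
                    stacklevel=2)
            return fn(*args, **kwargs)

        return wrapper

    return decorator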