@@ -204,36 +204,6 @@ def _init_kv_cache_quant(
         layer.quant_method.create_weights(layer)
 
 
-def get_attention_context(
-    layer_name: str,
-) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]:
-    """Extract attention context for a given layer.
-
-    This helper function extracts the attention metadata, attention layer
-    instance, and KV cache tensor for a specific layer.
-
-    Args:
-        layer_name: The name/identifier of the attention layer.
-
-    Returns:
-        A tuple containing:
-        - attn_metadata: Attention metadata for this specific layer, or None if
-          no metadata is available
-        - attn_layer: The attention layer instance (Attention or MLAAttention)
-        - kv_cache: The KV cache tensor for the current virtual engine
-
-    Note: attn_metadata may be None, but attn_layer and kv_cache are always
-    extracted from the forward context.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
-    return attn_metadata, attn_layer, kv_cache
-
-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -907,6 +877,36 @@ def maybe_calc_kv_scales_fake(
 )
 
 
+def get_attention_context(
+    layer_name: str,
+) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
+    """Extract attention context for a given layer.
+
+    This helper function extracts the attention metadata, attention layer
+    instance, and KV cache tensor for a specific layer.
+
+    Args:
+        layer_name: The name/identifier of the attention layer.
+
+    Returns:
+        A tuple containing:
+        - attn_metadata: Attention metadata for this specific layer, or None if
+          no metadata is available
+        - attn_layer: The attention layer instance (Attention or MLAAttention)
+        - kv_cache: The KV cache tensor for the current virtual engine
+
+    Note: attn_metadata may be None, but attn_layer and kv_cache are always
+    extracted from the forward context.
+    """
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+    return attn_metadata, attn_layer, kv_cache
+
+
 @maybe_transfer_kv_layer
 def unified_attention(
     query: torch.Tensor,
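
For reference, a minimal usage sketch of the relocated helper, not part of
this diff. The import path, the caller name, and the placeholder handling of
the metadata-free (profiling) case are illustrative assumptions:

    import torch

    # Assumed import path, based on the file this diff touches.
    from vllm.attention.layer import get_attention_context

    def attention_op_body(query: torch.Tensor, layer_name: str) -> torch.Tensor:
        # Unpack the per-layer context. Per the docstring above, attn_metadata
        # may be None, but attn_layer and kv_cache are always populated.
        attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
        if attn_metadata is None:
            # Profiling/dummy run: no metadata yet, so return a correctly
            # shaped placeholder instead of invoking the kernel.
            return torch.empty_like(query)
        # Real run: hand attn_layer, kv_cache, and attn_metadata to the
        # backend kernel (dispatch elided in this sketch).
        raise NotImplementedError("backend dispatch elided in this sketch")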