@@ -135,7 +135,7 @@ class AscendMetadata:
     # tokens + new tokens (is None if it is a decoding).
     # (batch_size,)
     seq_lens: torch.Tensor = None
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
@@ -183,8 +183,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        # Since FIA for GQA is not active now, we temporarily silence it
+        seq_lens_list = common_attn_metadata.seq_lens_list

         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
@@ -219,8 +220,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                               num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))

             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
@@ -407,96 +408,58 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if self.full_graph:
-                graph_params = get_graph_params()
-                q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                v = self.value_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                actual_seq_lens = attn_metadata.seq_lens_list
-                attn_args = {
-                    "query": q,
-                    "key": k,
-                    "value": v,
-                    "actual_seq_lengths_kv": actual_seq_lens,
-                    "block_table": attn_metadata.block_tables,
-                    "num_heads": self.num_heads,
-                    "scale": self.scale,
-                    "input_layout": "BSH",
-                    "num_key_value_heads": self.num_kv_heads,
-                    "block_size": self.block_size,
-                }
-
-                # Prepare tensors for attention output
-                # TODO: Refactor this to step-level instead of layer-level
-                attn_output = torch.empty(num_tokens,
-                                          1,
-                                          self.hidden_size,
-                                          dtype=output.dtype,
-                                          device=output.device)
-                softmax_lse = torch.empty(num_tokens,
-                                          dtype=output.dtype,
-                                          device=output.device)
-
-                # Get workspace from cache or calculate it if not present.
-                workspace = graph_params.workspaces.get(num_tokens)
-                if workspace is None:
-                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        **attn_args)
-                    graph_params.workspaces[num_tokens] = workspace
-
-                forward_context = get_forward_context()
-                if not forward_context.capturing:
-                    # Execute attention kernel directly in non-capturing mode
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                else:
-                    # Handle graph capturing mode
-                    stream = torch_npu.npu.current_stream()
-
-                    event = torch.npu.ExternalEvent()
-                    event.wait(stream)
-                    event.reset(stream)
-                    graph_params.events[num_tokens].append(event)
-
-                    graph_params.attn_params[num_tokens].append(
-                        (q, k, v, actual_seq_lens,
-                         attn_metadata.block_tables, self.num_heads,
-                         self.scale, self.num_kv_heads, attn_output,
-                         softmax_lse))
-
-                    torch.npu.graph_task_group_begin(stream)
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                    handle = torch.npu.graph_task_group_end(stream)
-                    graph_params.handles[num_tokens].append(handle)
-
-                # Reshape output to match the expected format
-                output.copy_(
-                    attn_output.view(num_tokens, self.num_heads,
-                                     self.head_size))
-            else:
+            graph_params = get_graph_params()
+
+            forward_context = get_forward_context()
+            if not forward_context.capturing:
                 if is_310p():
                     # seq_lens_tensor needs to be transferred to the device for 310P
                     attn_metadata.seq_lens = \
                         attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+            else:
+                # Handle graph capturing mode
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
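The parameters, task-group handles, and ExternalEvents recorded per num_tokens in the capturing branch above only pay off when the captured graph is replayed: something has to re-issue the paged-attention task with that step's metadata and then signal the event the graph is waiting on. A minimal sketch of such a replay-side update follows; it is not part of this commit, and the torch.npu.graph_task_update_begin/graph_task_update_end calls, the dedicated update_stream, and the helper name are assumptions mirroring the capture-side API used in the diff.

# Illustrative sketch (not from this commit): refresh the captured paged-attention
# tasks before replaying the decode graph for a given num_tokens bucket.
# Assumes torch.npu exposes graph_task_update_begin/graph_task_update_end symmetric
# to the graph_task_group_begin/graph_task_group_end calls used at capture time.
import torch
import torch_npu


def update_captured_attn_params(graph_params, num_tokens, update_stream):
    params = graph_params.attn_params[num_tokens]
    handles = graph_params.handles[num_tokens]
    events = graph_params.events[num_tokens]

    with torch.npu.stream(update_stream):
        for (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
             block_tables, seq_lens, output), handle, event in zip(
                 params, handles, events):
            # Re-issue the task recorded for this layer so the kernel picks up the
            # current contents of the (persistent) block_tables/seq_lens buffers.
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                num_kv_heads=num_kv_heads,
                num_heads=num_heads,
                scale_value=scale,
                block_table=block_tables,
                context_lens=seq_lens,
                out=output)
            torch.npu.graph_task_update_end(update_stream)
            # Signal the ExternalEvent the captured graph waits on before replay.
            event.record(update_stream)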