diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 226a5705cdd..4a8b5902d07 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -376,7 +376,10 @@ def build(
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
             block_table = block_table[:self._num_decode_tokens, ...]
-            if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
                 num_seqs = len(seq_lens)
                 if graph_pad_size != 0:
                     pad_value = 1
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2b343d7f7d4..51ffe1ec2d4 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -943,11 +943,6 @@ def _process_reqs(
             self.input_ids_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
         input_ids = self.input_ids[:num_input_tokens]
-        if (envs_ascend.VLLM_ENABLE_MC2
-                or self.torchair_graph_enabled) and not with_prefill:
-            input_ids = self.input_ids[:padded_batch_size]
-            positions = self.positions[:padded_batch_size]
-
         # prepare the MRoPE for mllm if using multimodal
         num_input_tokens = total_num_scheduled_tokens
         # _prepare_inputs may reorder the batch, so we must gather multi
@@ -985,6 +980,11 @@ def _process_reqs(
         else:
             positions = self.positions[:num_input_tokens]
 
+        if (envs_ascend.VLLM_ENABLE_MC2
+                or self.torchair_graph_enabled) and not with_prefill:
+            input_ids = self.input_ids[:padded_batch_size]
+            positions = self.positions[:padded_batch_size]
+
         # Run forward pass
         with set_forward_context(attn_metadata, self.vllm_config,
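
For context, a minimal, self-contained sketch (not the project's code) of what relocating the VLLM_ENABLE_MC2 / torchair padding block in model_runner_v1.py achieves: the padded re-slice of input_ids and positions now runs after the ordinary [:num_input_tokens] slicing, so the later, shorter slice can no longer overwrite the graph-sized views. All tensors, sizes, and flags below are hypothetical stand-ins.

import torch

# Hypothetical stand-ins for the runner's persistent buffers and sizes.
num_input_tokens = 3      # tokens actually scheduled this step
padded_batch_size = 8     # fixed batch size expected by the captured graph
use_padded_path = True    # stands in for (VLLM_ENABLE_MC2 or torchair_graph_enabled) and not with_prefill

input_ids_buf = torch.arange(16)
positions_buf = torch.arange(16)

# Ordinary slicing, done unconditionally.
input_ids = input_ids_buf[:num_input_tokens]
positions = positions_buf[:num_input_tokens]

# Padded re-slice applied last, mirroring the moved block: because it runs after
# the ordinary slicing, the graph-sized views are what reach the forward pass.
if use_padded_path:
    input_ids = input_ids_buf[:padded_batch_size]
    positions = positions_buf[:padded_batch_size]

assert input_ids.shape[0] == padded_batch_size
assert positions.shape[0] == padded_batch_size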