@@ -1,21 +1,19 @@
-from dataclasses import dataclass, replace
+from dataclasses import replace
 from typing import Any
 
 import torch
-
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.config.speculative import SpeculativeConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
-
-from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.attention.attention_v1 import AscendMetadata
 from vllm_ascend.attention.utils import extend_flat_seqs
+from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
 
 logger = init_logger(__name__)
 
@@ -39,18 +37,18 @@ def __init__(
         self._raise_if_vocab_size_mismatch()
         self._raise_if_draft_tp_mismatch()
 
-
-    def generate_token_ids(self,
-                           valid_sampled_token_ids: list[list[int]],
-                           sampling_metadata: SamplingMetadata = None,
-                           scheduler_output: SchedulerOutput = None,
-                           spec_decode_metadata: SpecDecodeMetadata = None,
-                           positions: torch.Tensor = None,
-                           num_scheduled_tokens: int = 0,
-                           hidden_states: torch.Tensor = None,
-                           attn_metadata=None,
-                           aux_hidden_states: torch.Tensor = None):
-
+    def generate_token_ids(
+        self,
+        valid_sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata = None,
+        scheduler_output: SchedulerOutput = None,
+        spec_decode_metadata: SpecDecodeMetadata = None,
+        positions: torch.Tensor = None,
+        num_scheduled_tokens: int = 0,
+        hidden_states: torch.Tensor = None,
+        attn_metadata=None,
+        aux_hidden_states: torch.Tensor = None,
+    ):
         attn_metadata = self._get_atten_dict(scheduler_output)
         attn_metadata = attn_metadata[self.attn_layer_name]
         next_token_ids: list[int] = []
@@ -63,23 +61,26 @@ def generate_token_ids(self,
             # Get the next token id from the request state.
             req_id = self.runner.input_batch.req_ids[i]
             req_state = self.runner.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
+            seq_len = (
+                req_state.num_computed_tokens
+                + scheduler_output.num_scheduled_tokens[req_id]
+            )
 
             next_token_id = req_state.get_token_id(seq_len)
             next_token_ids.append(next_token_id)
-        next_token_ids = torch.tensor(next_token_ids,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
+        next_token_ids = torch.tensor(
+            next_token_ids, dtype=torch.int32, device=self.device
+        )
+
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
-            cu_num_tokens = attn_metadata.query_start_loc
+            cu_num_tokens = attn_metadata.query_start_loc
         else:
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
-            num_rejected_tokens = [n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+            num_rejected_tokens = [
+                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
             ]
             num_rejected_tokens = torch.tensor(
@@ -88,22 +89,24 @@ def generate_token_ids(self,
8889 device = self .device ,
8990 )
9091 num_tokens = num_scheduled_tokens - sum (num_rejected_tokens )
-            cu_num_tokens, token_indices = self.prepare_inputs(
-                attn_metadata.query_start_loc, num_rejected_tokens,
-                num_tokens)
+            cu_num_tokens, token_indices = self._prepare_inputs(
+                attn_metadata.query_start_loc, num_rejected_tokens, num_tokens
+            )
             target_token_ids = self.runner.input_ids[token_indices]
-            target_positions = positions[token_indices]
+            target_positions = positions[token_indices]
 
-        (target_token_ids, target_positions,
-         target_slot_mapping, cu_num_tokens) = merge_next_token_ids_into_token_ids(
-            input_token_ids=target_token_ids,
-            input_positions=target_positions,
-            cad=attn_metadata,
-            next_token_ids=next_token_ids,
-            block_size=self.block_size,
-            max_model_len=self.vllm_config.model_config.max_model_len,
-            arange=self.arange,
-            cu_num_tokens=cu_num_tokens)
+        (target_token_ids, target_positions, target_slot_mapping, cu_num_tokens) = (
+            merge_next_token_ids_into_token_ids(
+                input_token_ids=target_token_ids,
+                input_positions=target_positions,
+                cad=attn_metadata,
+                next_token_ids=next_token_ids,
+                block_size=self.block_size,
+                max_model_len=self.vllm_config.model_config.max_model_len,
+                arange=self.arange,
+                cu_num_tokens=cu_num_tokens,
+            )
+        )
 
         draft_token_ids = self._propose(
             target_token_ids=target_token_ids,
@@ -118,8 +121,6 @@ def generate_token_ids(self,
         spec_token_ids = draft_token_ids.tolist()
 
         return spec_token_ids
-
-
 
     def _raise_if_mrope(self):
         if self.draft_model_config.uses_mrope:
@@ -135,23 +136,23 @@ def _raise_if_padded_drafter_batch(self):
135136 "in the speculative_config."
136137 )
137138
138- def _raise_if_vocab_size_mismatch (self ):
139- speculative_config = self .vllm_config .speculative_config
140- if (
141- speculative_config .method == "draft_model"
142- and speculative_config .target_model_config is not None
143- and speculative_config .draft_model_config is not None
144- ):
145- target_vocab_size = speculative_config .target_model_config .get_vocab_size ()
146- draft_vocab_size = speculative_config .draft_model_config .get_vocab_size ()
147- if target_vocab_size != draft_vocab_size :
148- raise ValueError (
149- f"Target and draft model should have the same vocabulary size. "
150- f"Target model vocab_size={ target_vocab_size } . "
151- f"Draft model vocab_size={ draft_vocab_size } . "
152- f"Using models with different tokenizers can cause out-of-bounds "
153- f"errors during speculative decoding."
154- )
139+ def _raise_if_vocab_size_mismatch (self ):
140+ speculative_config = self .vllm_config .speculative_config
141+ if (
142+ speculative_config .method == "draft_model"
143+ and speculative_config .target_model_config is not None
144+ and speculative_config .draft_model_config is not None
145+ ):
146+ target_vocab_size = speculative_config .target_model_config .get_vocab_size ()
147+ draft_vocab_size = speculative_config .draft_model_config .get_vocab_size ()
148+ if target_vocab_size != draft_vocab_size :
149+ raise ValueError (
150+ f"Target and draft model should have the same vocabulary size. "
151+ f"Target model vocab_size={ target_vocab_size } . "
152+ f"Draft model vocab_size={ draft_vocab_size } . "
153+ f"Using models with different tokenizers can cause out-of-bounds "
154+ f"errors during speculative decoding."
155+ )
155156
156157 def _raise_if_draft_tp_mismatch (self ):
157158 # Note(Tomas Ruiz) If we run the target model with TP > 1 and
@@ -210,6 +211,7 @@ def load_model(self, target_model: Any) -> None:
         )
         self.attn_layer_name = next(iter(draft_attn_layer_names))
 
+
 def create_vllm_config_for_draft_model(
     target_model_vllm_config: VllmConfig,
 ) -> VllmConfig:
@@ -220,17 +222,19 @@ def create_vllm_config_for_draft_model(
     The vllm_config is useful when loading the draft model with get_model().
     """
     old = target_model_vllm_config
-    new_parallel_config = replace(old.speculative_config.draft_parallel_config,
-                                  rank=old.parallel_config.rank
+    new_parallel_config = replace(
+        old.speculative_config.draft_parallel_config, rank=old.parallel_config.rank
     )
-
-    new: VllmConfig = replace(old,
+
+    new: VllmConfig = replace(
+        old,
         quant_config=None,  # quant_config is recomputed in __init__()
         model_config=old.speculative_config.draft_model_config,
         parallel_config=new_parallel_config,
     )
     return new
 
+
 def merge_next_token_ids_into_token_ids(
     input_token_ids: torch.Tensor,
     input_positions: torch.Tensor,
@@ -239,8 +243,8 @@ def merge_next_token_ids_into_token_ids(
     block_size: int,
     max_model_len: int,
     arange: torch.Tensor,
-    cu_num_tokens
-):
+    cu_num_tokens,
+):
     """
     Merges the next token ids with the existing token ids into a flat sequence.
     Does the same for the positions, computes new slot mapping,
@@ -251,7 +255,7 @@ def merge_next_token_ids_into_token_ids(
         seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids
     )
     logger.warning("new_token_ids: {}".format(new_token_ids))
-
+
     # append new positions
     positions_to_append = input_positions[query_end_locs] + 1
     new_positions = extend_flat_seqs(
@@ -260,16 +264,18 @@ def merge_next_token_ids_into_token_ids(
     # recompute slot mapping
     batch_size, n_blocks_per_req = cad.block_tables.shape
     req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
-
+
     query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-    req_indices = torch.repeat_interleave(req_indices, query_lens.to(cad.query_start_loc.device) + 1)
+    req_indices = torch.repeat_interleave(
+        req_indices, query_lens.to(cad.query_start_loc.device) + 1
+    )
     block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
     block_nums = cad.block_tables.view(-1)[block_table_indices]
     block_offsets = new_positions % block_size
     new_slot_mapping = block_nums * block_size + block_offsets
     # Mask out the position ids that exceed the max model length.
     exceeds_max_model_len = new_positions >= max_model_len
     new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
-
-    cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
-    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
+
+    cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
+    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
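
Note on the merge step: `merge_next_token_ids_into_token_ids` leans on `extend_flat_seqs`, which is not shown in this diff. As a rough illustration of the intended effect (an assumption based on how the helper is called here, not the vllm_ascend implementation), the plain-PyTorch sketch below appends one freshly sampled token at the end of each request's segment of a flat batch, and shows why the updated offsets come out as `cu_num_tokens + arange[: len(cu_num_tokens)]`:

```python
import torch

# Flat batch: request 0 holds tokens [10, 11, 12], request 1 holds [20, 21].
token_ids = torch.tensor([10, 11, 12, 20, 21])
cu_num_tokens = torch.tensor([0, 3, 5])      # query_start_loc-style offsets
next_token_ids = torch.tensor([13, 22])      # one sampled token per request

# Each request grows by exactly one token, so the new offsets are the old
# ones shifted by [0, 1, 2, ...] -- the arange trick used in the diff.
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
new_cu = torch.cat([torch.zeros(1, dtype=torch.long), (query_lens + 1).cumsum(0)])

# Rebuild the flat sequence with one extra slot at the end of each segment.
merged = torch.empty(int(new_cu[-1]), dtype=token_ids.dtype)
for i in range(len(next_token_ids)):
    start, end = int(cu_num_tokens[i]), int(cu_num_tokens[i + 1])
    merged[int(new_cu[i]):int(new_cu[i + 1]) - 1] = token_ids[start:end]
    merged[int(new_cu[i + 1]) - 1] = next_token_ids[i]

print(merged.tolist())  # [10, 11, 12, 13, 20, 21, 22]
print(new_cu.tolist())  # [0, 4, 7] == (cu_num_tokens + torch.arange(3)).tolist()
```

The function above does the same for positions (appending `input_positions[query_end_locs] + 1` per request) before recomputing the slot mapping.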
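The slot-mapping arithmetic at the end of the function is also easier to see with concrete numbers. The block table below is invented for illustration, and `PADDING_SLOT_ID` is assumed to be a negative sentinel (the real value is imported from `vllm.v1.spec_decode.eagle`); the point is only how `req_indices * n_blocks_per_req + position // block_size` selects a physical block per token and `position % block_size` the offset inside it:

```python
import torch

PADDING_SLOT_ID = -1  # assumed sentinel value for masked-out positions
block_size = 4
max_model_len = 8

# Two requests, each with room for 4 KV-cache blocks (made-up block numbers).
block_tables = torch.tensor([[7, 2, 9, 0],
                             [5, 3, 8, 0]])
batch_size, n_blocks_per_req = block_tables.shape

# Flat positions after the merge: request 0 covers positions 0..3,
# request 1 covers positions 6..8 (the appended token included).
new_positions = torch.tensor([0, 1, 2, 3, 6, 7, 8])
req_indices = torch.tensor([0, 0, 0, 0, 1, 1, 1])

block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
block_nums = block_tables.view(-1)[block_table_indices]
new_slot_mapping = block_nums * block_size + new_positions % block_size

# Positions at or beyond max_model_len are pointed at the padding slot.
new_slot_mapping.masked_fill_(new_positions >= max_model_len, PADDING_SLOT_ID)
print(new_slot_mapping.tolist())  # [28, 29, 30, 31, 14, 15, -1]
```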