from dataclasses import replace
from typing import Any

import torch

from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID

from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.attention.utils import extend_flat_seqs

logger = init_logger(__name__)


class DraftModelProposer(SpecDecodeBaseProposer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        self.draft_model_config = vllm_config.speculative_config.draft_model_config
        self._raise_if_mrope()
        self._raise_if_padded_drafter_batch()
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        attn_metadata = self._get_atten_dict(scheduler_output)
        attn_metadata = attn_metadata[self.attn_layer_name]

        next_token_ids: list[int] = []
        for i, token_ids in enumerate(valid_sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = self.runner.input_batch.req_ids[i]
                req_state = self.runner.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.device)

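        # Build the drafter inputs: without verifier metadata every scheduled
        # token is a target token; with it, tokens rejected during
        # verification are filtered out first.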
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            cu_num_tokens = attn_metadata.query_start_loc
        else:
            num_draft_tokens = spec_decode_metadata.num_draft_tokens
            num_rejected_tokens = [
                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]
            # Compute the token count from the Python list so it stays an int.
            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
            num_rejected_tokens = torch.tensor(
                num_rejected_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            cu_num_tokens, token_indices = self.prepare_inputs(
                attn_metadata.query_start_loc, num_rejected_tokens,
                num_tokens)
            target_token_ids = self.runner.input_ids[token_indices]
            target_positions = positions[token_indices]

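        # Append each request's newly sampled token to its flat target
        # sequence so the draft model conditions on it as well.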
        (target_token_ids, target_positions, target_slot_mapping,
         cu_num_tokens) = merge_next_token_ids_into_token_ids(
             input_token_ids=target_token_ids,
             input_positions=target_positions,
             cad=attn_metadata,
             next_token_ids=next_token_ids,
             block_size=self.block_size,
             max_model_len=self.vllm_config.model_config.max_model_len,
             arange=self.arange,
             cu_num_tokens=cu_num_tokens)

        draft_token_ids = self._propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=None,
            target_slot_mapping=target_slot_mapping.to(torch.int32),
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=attn_metadata.block_tables,
            sampling_metadata=sampling_metadata,
        )
        spec_token_ids = draft_token_ids.tolist()

        return spec_token_ids

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative decoding with draft models does not support M-RoPE yet."
            )

    def _raise_if_padded_drafter_batch(self):
        if not self.vllm_config.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative decoding with draft models does not support "
                "padded drafter batches yet. Please set "
                "disable_padded_drafter_batch=True in the speculative_config."
            )

    def _raise_if_vocab_size_mismatch(self):
        speculative_config = self.vllm_config.speculative_config
        if (speculative_config.method == "draft_model"
                and speculative_config.target_model_config is not None
                and speculative_config.draft_model_config is not None):
            target_vocab_size = speculative_config.target_model_config.get_vocab_size()
            draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
            if target_vocab_size != draft_vocab_size:
                raise ValueError(
                    "Target and draft model must have the same vocabulary size. "
                    f"Target model vocab_size={target_vocab_size}. "
                    f"Draft model vocab_size={draft_vocab_size}. "
                    "Using models with different tokenizers can cause "
                    "out-of-bounds errors during speculative decoding.")

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz): If the target model runs with TP > 1 and the
        # draft model with TP = 1, the different TP ranks collide:
        # all ranks compile the draft model on rank 0 (because TP=1), so the
        # torch compile cache is overwritten and corrupted. A mechanism like
        # https://github.com/vllm-project/vllm/pull/5414 would be needed.
        # To prevent this error, we require both TP sizes to be the same.
        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                "Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
            )

    def set_input_ids_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        num_tokens: int,
        last_token_indices: torch.Tensor,
    ) -> None:
        # next_token_ids and last_token_indices are unused here: the next
        # tokens were already merged into target_token_ids by
        # merge_next_token_ids_into_token_ids().
        self.input_ids[:num_tokens] = target_token_ids

    def load_model(self, target_model: Any) -> None:
        """Takes target_model to satisfy the type checker."""

        # This must be computed before loading the draft model
        # because that mutates the forward_context of the vllm_config.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag

        draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
            target_model_vllm_config=self.vllm_config)
        logger.info(
            "Starting to load draft model %s. TP=%d, rank=%d",
            draft_vllm_config.model_config.model,
            draft_vllm_config.parallel_config.tensor_parallel_size,
            draft_vllm_config.parallel_config.rank,
        )
        with set_model_tag("draft_model"):
            self.model = get_model(vllm_config=draft_vllm_config,
                                   prefix="draft_model")

        # This must be computed after loading the draft model
        # because that mutates the forward_context of the vllm_config.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_name = next(iter(draft_attn_layer_names))

def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Create a new vllm_config configured for the draft model.

    The given vllm_config is configured for the target model, e.g. its
    quant_config and parallel_config, but the draft model may be quantized
    differently and may use a different tensor_parallel_size. The returned
    config is used when loading the draft model with get_model().
    """
    old = target_model_vllm_config
    new_parallel_config = replace(old.speculative_config.draft_parallel_config,
                                  rank=old.parallel_config.rank)
    new: VllmConfig = replace(
        old,
        quant_config=None,  # quant_config is recomputed in __init__()
        model_config=old.speculative_config.draft_model_config,
        parallel_config=new_parallel_config,
    )
    return new

def merge_next_token_ids_into_token_ids(
    input_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    cad: AscendMetadata,
    next_token_ids: torch.Tensor,
    block_size: int,
    max_model_len: int,
    arange: torch.Tensor,
    cu_num_tokens: torch.Tensor,
):
    """
    Merges the next token ids with the existing token ids into a flat sequence.
    Does the same for the positions, computes a new slot mapping, and shifts
    the cumulative token counts. The inputs are not modified in-place.
    """
    query_end_locs = cu_num_tokens[1:] - 1
    new_token_ids = extend_flat_seqs(seqs=input_token_ids,
                                     end_locs=query_end_locs,
                                     new_vals=next_token_ids)
    logger.debug("new_token_ids: %s", new_token_ids)

    # Append the new positions.
    positions_to_append = input_positions[query_end_locs] + 1
    new_positions = extend_flat_seqs(seqs=input_positions,
                                     end_locs=query_end_locs,
                                     new_vals=positions_to_append)

    # Recompute the slot mapping.
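    # Each new position is mapped to a physical KV-cache slot through the
    # per-request block table: block_tables[req, pos // block_size] gives the
    # block number and pos % block_size the offset within that block.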
    batch_size, n_blocks_per_req = cad.block_tables.shape
    req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)

    query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
    req_indices = torch.repeat_interleave(
        req_indices, query_lens.to(cad.query_start_loc.device) + 1)
    block_table_indices = (req_indices * n_blocks_per_req +
                           new_positions // block_size)
    block_nums = cad.block_tables.view(-1)[block_table_indices]
    block_offsets = new_positions % block_size
    new_slot_mapping = block_nums * block_size + block_offsets
    # Mask out the position ids that exceed the max model length.
    exceeds_max_model_len = new_positions >= max_model_len
    new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

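    # Every request gained exactly one token, so shift the cumulative token
    # counts by [0, 1, 2, ...] across the batch.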
    cu_num_tokens = cu_num_tokens + arange[:len(cu_num_tokens)]
    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)