Commit 9b219c7

Author: 01267596 (committed)

[feat] add draft_model spec_decode

Signed-off-by: 01267596 <[email protected]>

1 parent 46d5a77 · commit 9b219c7

8 files changed, +388 −28 lines changed

vllm_ascend/attention/utils.py

Lines changed: 31 additions & 0 deletions
@@ -106,6 +106,13 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None
 
+    max_seq_len: int = -1
+
+    def batch_size(self) -> int:
+        return self.seq_lens_cpu.shape[0]
+
+    def query_lens(self) -> torch.Tensor:
+        return self.query_start_loc[1:] - self.query_start_loc[:-1]
 
 
 def split_decodes_and_prefills(
         common_attn_metadata: AscendCommonAttentionMetadata,

@@ -212,3 +219,27 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
         nz_mat,
         (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
     return nz_mat
+
+def extend_flat_seqs(
+        seqs: torch.Tensor,
+        end_locs: torch.Tensor,
+        new_vals: torch.Tensor
+) -> torch.Tensor:
+    """
+    Append a single new value to each of several sequences stored in one
+    flat tensor, e.g. the flat sequences [x1, x2] and [y1] (stored as
+    [x1, x2, y1]) plus new values [x3, y2] become [x1, x2, x3, y1, y2].
+    """
+    new_len = seqs.shape[0] + new_vals.shape[0]
+    new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype)
+    # indices for previous seqs
+    start_locs = end_locs[:-1] + 1
+    seqs_new_idxs = torch.ones_like(seqs)
+    seqs_new_idxs[start_locs] += 1
+    seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1
+    # indices for new values
+    new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device)
+    # assign seqs and new vals
+    new_seqs[seqs_new_idxs] = seqs
+    new_seqs[new_val_idxs] = new_vals
+    return new_seqs
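
A quick usage sketch of the new helpers (editor's illustration, not part of the commit; the tensor values are made up). The flat sequences [x1, x2] and [y1] from the docstring map onto concrete integers, and query_lens() is simply the difference of consecutive query_start_loc entries:

    import torch

    from vllm_ascend.attention.utils import extend_flat_seqs

    # Two flat sequences: [x1, x2] and [y1], ending at flat indices 1 and 2.
    seqs = torch.tensor([11, 12, 21])      # x1, x2, y1
    end_locs = torch.tensor([1, 2])        # last flat index of each sequence
    new_vals = torch.tensor([13, 22])      # x3 for seq 1, y2 for seq 2

    print(extend_flat_seqs(seqs, end_locs, new_vals))
    # tensor([11, 12, 13, 21, 22])  ->  [x1, x2, x3, y1, y2]

    # query_lens() is the usual prefix-sum difference over query_start_loc:
    query_start_loc = torch.tensor([0, 2, 3])
    print(query_start_loc[1:] - query_start_loc[:-1])   # tensor([2, 1])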

vllm_ascend/core/schedule_config.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@
 class AscendSchedulerConfig(SchedulerConfig):
     enable_chunked_prefill: bool = False
     max_long_partial_prefills: int = 1
+    max_num_partial_prefills: int = 1
     long_prefill_token_threshold: int = MAX_INT
     policy: str = "fcfs"
     scheduler_cls: Union[str, Type[object]] = (

@@ -47,6 +48,7 @@ def initialize_from_config(
         # Override default values into original SchedulerConfig
         scheduler_config["enable_chunked_prefill"] = False
         scheduler_config["max_long_partial_prefills"] = None
+        scheduler_config["max_num_partial_prefills"] = None
         scheduler_config["long_prefill_token_threshold"] = None
         scheduler_config["policy"] = "fcfs"
         scheduler_config["scheduler_cls"] = (

@@ -78,6 +80,9 @@ def __post_init__(self, *args) -> None:
             self.max_long_partial_prefills = 1
             self.long_prefill_token_threshold = MAX_INT
 
+        if self.max_num_partial_prefills is None:
+            self.max_num_partial_prefills = 1
+
         if self.long_prefill_token_threshold is None or \
                 self.long_prefill_token_threshold <= 0:
             if self.max_model_len is None:

vllm_ascend/patch/platform/patch_config.py

Lines changed: 0 additions & 5 deletions
@@ -155,11 +155,6 @@ def __post_init__(self):
             )
         else:
             self.method = "draft_model"
-            raise NotImplementedError(
-                "Speculative decoding with draft model is not "
-                "supported yet. Please consider using other "
-                "speculative decoding methods such as ngram, medusa, "
-                "eagle, or deepseek_mtp.")
 
         # Replace hf_config for EAGLE draft_model
         if self.method in ("eagle", "eagle3"):
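
With this NotImplementedError removed, a plain causal-LM draft model is accepted and the patched __post_init__ auto-detects method = "draft_model" (the else branch above). A hedged sketch of what such a run might look like (editor's illustration, not part of the commit): the model names are placeholders, and the keys mirror upstream vLLM's SpeculativeConfig fields that this commit relies on (disable_padded_drafter_batch, draft_tensor_parallel_size).

    from vllm import LLM

    # Illustrative only: model names are placeholders.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",           # target model (placeholder)
        tensor_parallel_size=2,
        speculative_config={
            "model": "Qwen/Qwen2.5-0.5B-Instruct",  # draft model (placeholder)
            "num_speculative_tokens": 3,
            # DraftModelProposer (added below) requires these two settings:
            "draft_tensor_parallel_size": 2,        # must equal the target TP size
            "disable_padded_drafter_batch": True,
        },
    )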

vllm_ascend/spec_decode/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
 from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
+from vllm_ascend.spec_decode.draft_proposer import DraftModelProposer
 
 
 def get_spec_decode_method(method,

@@ -35,6 +36,8 @@ def get_spec_decode_method(method,
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
+    elif method == 'draft_model':
+        return DraftModelProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")
vllm_ascend/spec_decode/draft_proposer.py (new file)

Lines changed: 275 additions & 0 deletions

@@ -0,0 +1,275 @@
from dataclasses import dataclass, replace
from typing import Any

import torch

from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID

from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.attention.utils import extend_flat_seqs

logger = init_logger(__name__)


class DraftModelProposer(SpecDecodeBaseProposer):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        self.draft_model_config = vllm_config.speculative_config.draft_model_config
        self._raise_if_mrope()
        self._raise_if_padded_drafter_batch()
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):

        attn_metadata = self._get_atten_dict(scheduler_output)
        attn_metadata = attn_metadata[self.attn_layer_name]
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(valid_sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = self.runner.input_batch.req_ids[i]
                req_state = self.runner.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])

                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.device)

        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            cu_num_tokens = attn_metadata.query_start_loc
        else:
            num_draft_tokens = spec_decode_metadata.num_draft_tokens
            num_rejected_tokens = [
                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]
            num_rejected_tokens = torch.tensor(
                num_rejected_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
            cu_num_tokens, token_indices = self.prepare_inputs(
                attn_metadata.query_start_loc, num_rejected_tokens,
                num_tokens)
            target_token_ids = self.runner.input_ids[token_indices]
            target_positions = positions[token_indices]

        (target_token_ids, target_positions, target_slot_mapping,
         cu_num_tokens) = merge_next_token_ids_into_token_ids(
             input_token_ids=target_token_ids,
             input_positions=target_positions,
             cad=attn_metadata,
             next_token_ids=next_token_ids,
             block_size=self.block_size,
             max_model_len=self.vllm_config.model_config.max_model_len,
             arange=self.arange,
             cu_num_tokens=cu_num_tokens)

        draft_token_ids = self._propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=None,
            target_slot_mapping=target_slot_mapping.to(torch.int32),
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=attn_metadata.block_tables,
            sampling_metadata=sampling_metadata,
        )
        spec_token_ids = draft_token_ids.tolist()

        return spec_token_ids

    def _raise_if_mrope(self):
        if self.draft_model_config.uses_mrope:
            raise NotImplementedError(
                "Speculative Decoding with draft models does not support M-RoPE yet"
            )

    def _raise_if_padded_drafter_batch(self):
        if not self.vllm_config.speculative_config.disable_padded_drafter_batch:
            raise NotImplementedError(
                "Speculative Decoding with draft models does not support "
                "padded drafter batch yet. Please pass --disable-padded-drafter-batch "
                "in the speculative_config.")

    def _raise_if_vocab_size_mismatch(self):
        speculative_config = self.vllm_config.speculative_config
        if (speculative_config.method == "draft_model"
                and speculative_config.target_model_config is not None
                and speculative_config.draft_model_config is not None):
            target_vocab_size = speculative_config.target_model_config.get_vocab_size()
            draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
            if target_vocab_size != draft_vocab_size:
                raise ValueError(
                    f"Target and draft model should have the same vocabulary size. "
                    f"Target model vocab_size={target_vocab_size}. "
                    f"Draft model vocab_size={draft_vocab_size}. "
                    f"Using models with different tokenizers can cause out-of-bounds "
                    f"errors during speculative decoding.")

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz): If we run the target model with TP > 1 and
        # the draft model with TP = 1, then the different TP ranks collide.
        # Specifically, when all ranks compile the draft model on rank 0
        # (because TP=1), the torch compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config.")

    def set_input_ids_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        num_tokens: int,
        last_token_indices: torch.Tensor,
    ) -> None:
        self.input_ids[:num_tokens] = target_token_ids

    def load_model(self, target_model: Any) -> None:
        """Takes target_model to satisfy the type checker."""

        # This must be computed before loading the draft model,
        # because loading it mutates the forward_context of the vllm_config.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag

        draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
            target_model_vllm_config=self.vllm_config)
        logger.info(
            "Starting to load draft model %s. TP=%d, rank=%d",
            draft_vllm_config.model_config.model,
            draft_vllm_config.parallel_config.tensor_parallel_size,
            draft_vllm_config.parallel_config.rank,
        )
        with set_model_tag("draft_model"):
            self.model = get_model(vllm_config=draft_vllm_config,
                                   prefix="draft_model")

        # This must be computed after loading the draft model,
        # because loading it mutates the forward_context of the vllm_config.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_name = next(iter(draft_attn_layer_names))


def create_vllm_config_for_draft_model(
        target_model_vllm_config: VllmConfig) -> VllmConfig:
    """The vllm_config is configured for the target model, e.g.
    its quant_config and parallel_config. But the draft model is potentially
    quantized differently and may have a different tensor_parallel_size.
    This function creates a new vllm_config configured for the draft model.
    The vllm_config is useful when loading the draft model with get_model().
    """
    old = target_model_vllm_config
    new_parallel_config = replace(old.speculative_config.draft_parallel_config,
                                  rank=old.parallel_config.rank)

    new: VllmConfig = replace(
        old,
        quant_config=None,  # quant_config is recomputed in __init__()
        model_config=old.speculative_config.draft_model_config,
        parallel_config=new_parallel_config,
    )
    return new


def merge_next_token_ids_into_token_ids(
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        cad: AscendMetadata,
        next_token_ids: torch.Tensor,
        block_size: int,
        max_model_len: int,
        arange: torch.Tensor,
        cu_num_tokens):
    """
    Merges the next token ids with the existing token ids into a flat sequence.
    Does the same for the positions, computes the new slot mapping,
    and updates the common_attn_metadata. The inputs are not modified in-place.
    """
    query_end_locs = cu_num_tokens[1:] - 1
    new_token_ids = extend_flat_seqs(
        seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids)
    logger.warning("new_token_ids: {}".format(new_token_ids))

    # append new positions
    positions_to_append = input_positions[query_end_locs] + 1
    new_positions = extend_flat_seqs(
        seqs=input_positions, end_locs=query_end_locs,
        new_vals=positions_to_append)
    # recompute slot mapping
    batch_size, n_blocks_per_req = cad.block_tables.shape
    req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)

    query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
    req_indices = torch.repeat_interleave(
        req_indices, query_lens.to(cad.query_start_loc.device) + 1)
    block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
    block_nums = cad.block_tables.view(-1)[block_table_indices]
    block_offsets = new_positions % block_size
    new_slot_mapping = block_nums * block_size + block_offsets
    # Mask out the position ids that exceed the max model length.
    exceeds_max_model_len = new_positions >= max_model_len
    new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

    cu_num_tokens = cu_num_tokens + arange[:len(cu_num_tokens)]
    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
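
Editor's illustration of the slot-mapping arithmetic used in merge_next_token_ids_into_token_ids() above (not part of the commit; the block-table contents and positions are made up). Each token's flat KV-cache slot is block_table[req, pos // block_size] * block_size + pos % block_size:

    import torch

    block_size = 4
    block_tables = torch.tensor([[7, 3],    # request 0 owns physical blocks 7 and 3
                                 [5, 9]])   # request 1 owns physical blocks 5 and 9
    batch_size, n_blocks_per_req = block_tables.shape

    # Flat per-token positions after appending one bonus token per request:
    # request 0 contributes positions [4, 5], request 1 contributes [2, 3].
    new_positions = torch.tensor([4, 5, 2, 3])
    req_indices = torch.tensor([0, 0, 1, 1])   # repeat_interleave(arange, query_lens + 1)

    block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
    block_nums = block_tables.view(-1)[block_table_indices]   # tensor([3, 3, 5, 5])
    slot_mapping = block_nums * block_size + new_positions % block_size
    print(slot_mapping)   # tensor([12, 13, 22, 23])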

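A further editor's note on the rejected-token accounting in generate_token_ids(): with n draft tokens, the verifier returns at most n accepted drafts plus one bonus token, so the number of rejected drafts is n + 1 minus the number of tokens actually returned for that request (and 0 when no drafts were proposed). A tiny standalone check with made-up values:

    num_draft_tokens = [3, 3, 0]
    valid_sampled_token_ids = [
        [101, 102, 103, 104],  # all 3 drafts accepted + bonus token
        [101, 102],            # 1 draft accepted + bonus token
        [55],                  # no speculation for this request
    ]

    num_rejected_tokens = [
        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    print(num_rejected_tokens)   # [0, 2, 0]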