Commit 51506a1

author 01267596 committed
[feat] add native kvcache offload
Signed-off-by: 01267596 <[email protected]>
1 parent 9b219c7 commit 51506a1

File tree

1 file changed: +76 -70 lines changed


vllm_ascend/spec_decode/draft_proposer.py

Lines changed: 76 additions & 70 deletions
@@ -1,21 +1,19 @@
-from dataclasses import dataclass, replace
+from dataclasses import replace
 from typing import Any

 import torch
-
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.config.speculative import SpeculativeConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
-
-from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.attention.attention_v1 import AscendMetadata
 from vllm_ascend.attention.utils import extend_flat_seqs
+from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer

 logger = init_logger(__name__)

@@ -39,18 +37,18 @@ def __init__(
         self._raise_if_vocab_size_mismatch()
         self._raise_if_draft_tp_mismatch()

-
-    def generate_token_ids(self,
-                           valid_sampled_token_ids: list[list[int]],
-                           sampling_metadata: SamplingMetadata = None,
-                           scheduler_output: SchedulerOutput = None,
-                           spec_decode_metadata: SpecDecodeMetadata = None,
-                           positions: torch.Tensor = None,
-                           num_scheduled_tokens: int = 0,
-                           hidden_states: torch.Tensor = None,
-                           attn_metadata=None,
-                           aux_hidden_states: torch.Tensor = None):
-
+    def generate_token_ids(
+        self,
+        valid_sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata = None,
+        scheduler_output: SchedulerOutput = None,
+        spec_decode_metadata: SpecDecodeMetadata = None,
+        positions: torch.Tensor = None,
+        num_scheduled_tokens: int = 0,
+        hidden_states: torch.Tensor = None,
+        attn_metadata=None,
+        aux_hidden_states: torch.Tensor = None,
+    ):
         attn_metadata = self._get_atten_dict(scheduler_output)
         attn_metadata = attn_metadata[self.attn_layer_name]
         next_token_ids: list[int] = []
@@ -63,23 +61,26 @@ def generate_token_ids(self,
             # Get the next token id from the request state.
             req_id = self.runner.input_batch.req_ids[i]
             req_state = self.runner.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
+            seq_len = (
+                req_state.num_computed_tokens
+                + scheduler_output.num_scheduled_tokens[req_id]
+            )

             next_token_id = req_state.get_token_id(seq_len)
             next_token_ids.append(next_token_id)
-        next_token_ids = torch.tensor(next_token_ids,
-                                      dtype=torch.int32,
-                                      device=self.device)
-
+        next_token_ids = torch.tensor(
+            next_token_ids, dtype=torch.int32, device=self.device
+        )
+
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
-            cu_num_tokens =attn_metadata.query_start_loc
+            cu_num_tokens = attn_metadata.query_start_loc
         else:
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
-            num_rejected_tokens = [n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+            num_rejected_tokens = [
+                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
             ]
             num_rejected_tokens = torch.tensor(
@@ -88,22 +89,24 @@ def generate_token_ids(self,
                 device=self.device,
             )
             num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
-            cu_num_tokens, token_indices = self.prepare_inputs(
-                attn_metadata.query_start_loc, num_rejected_tokens,
-                num_tokens)
+            cu_num_tokens, token_indices = self._prepare_inputs(
+                attn_metadata.query_start_loc, num_rejected_tokens, num_tokens
+            )
             target_token_ids = self.runner.input_ids[token_indices]
-            target_positions = positions[token_indices]
+            target_positions = positions[token_indices]

-        (target_token_ids, target_positions,
-         target_slot_mapping, cu_num_tokens) = merge_next_token_ids_into_token_ids(
-            input_token_ids=target_token_ids,
-            input_positions=target_positions,
-            cad=attn_metadata,
-            next_token_ids=next_token_ids,
-            block_size=self.block_size,
-            max_model_len=self.vllm_config.model_config.max_model_len,
-            arange=self.arange,
-            cu_num_tokens=cu_num_tokens)
+        (target_token_ids, target_positions, target_slot_mapping, cu_num_tokens) = (
+            merge_next_token_ids_into_token_ids(
+                input_token_ids=target_token_ids,
+                input_positions=target_positions,
+                cad=attn_metadata,
+                next_token_ids=next_token_ids,
+                block_size=self.block_size,
+                max_model_len=self.vllm_config.model_config.max_model_len,
+                arange=self.arange,
+                cu_num_tokens=cu_num_tokens,
+            )
+        )

         draft_token_ids = self._propose(
             target_token_ids=target_token_ids,
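A note on the rejected-token arithmetic reformatted in this hunk: for each request, `n` draft tokens plus one bonus token were verified, and `valid_sampled_token_ids[i]` holds the tokens that survived verification, so `n + 1 - len(valid_sampled_token_ids[i])` counts the rejected positions. The following is a minimal sketch of that bookkeeping with made-up numbers; it is illustrative only and not code from this repository.

# Illustrative example of the rejected-token arithmetic shown above.
# All request and token values below are invented for demonstration.
num_draft_tokens = [3, 0, 2]          # draft tokens scheduled per request
valid_sampled_token_ids = [
    [11, 12],                         # request 0: 2 of 3+1 positions accepted
    [21],                             # request 1: no drafts, only the sampled token
    [31, 32, 33],                     # request 2: all 2+1 positions accepted
]

num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert num_rejected_tokens == [2, 0, 0]

# Mirrors num_tokens = num_scheduled_tokens - sum(num_rejected_tokens):
# the flattened target batch shrinks by the rejected positions.
num_scheduled_tokens = (3 + 1) + 1 + (2 + 1)   # 8 flattened positions
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
assert num_tokens == 6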
@@ -118,8 +121,6 @@ def generate_token_ids(self,
         spec_token_ids = draft_token_ids.tolist()

         return spec_token_ids
-
-

     def _raise_if_mrope(self):
         if self.draft_model_config.uses_mrope:
@@ -135,23 +136,23 @@ def _raise_if_padded_drafter_batch(self):
                 "in the speculative_config."
             )

-    def _raise_if_vocab_size_mismatch(self):
-        speculative_config = self.vllm_config.speculative_config
-        if (
-            speculative_config.method == "draft_model"
-            and speculative_config.target_model_config is not None
-            and speculative_config.draft_model_config is not None
-        ):
-            target_vocab_size = speculative_config.target_model_config.get_vocab_size()
-            draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
-            if target_vocab_size != draft_vocab_size:
-                raise ValueError(
-                    f"Target and draft model should have the same vocabulary size. "
-                    f"Target model vocab_size={target_vocab_size}. "
-                    f"Draft model vocab_size={draft_vocab_size}. "
-                    f"Using models with different tokenizers can cause out-of-bounds "
-                    f"errors during speculative decoding."
-                )
+    def _raise_if_vocab_size_mismatch(self):
+        speculative_config = self.vllm_config.speculative_config
+        if (
+            speculative_config.method == "draft_model"
+            and speculative_config.target_model_config is not None
+            and speculative_config.draft_model_config is not None
+        ):
+            target_vocab_size = speculative_config.target_model_config.get_vocab_size()
+            draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
+            if target_vocab_size != draft_vocab_size:
+                raise ValueError(
+                    f"Target and draft model should have the same vocabulary size. "
+                    f"Target model vocab_size={target_vocab_size}. "
+                    f"Draft model vocab_size={draft_vocab_size}. "
+                    f"Using models with different tokenizers can cause out-of-bounds "
+                    f"errors during speculative decoding."
+                )

     def _raise_if_draft_tp_mismatch(self):
         # Note(Tomas Ruiz) If we run the target model with TP > 1 and
@@ -210,6 +211,7 @@ def load_model(self, target_model: Any) -> None:
         )
         self.attn_layer_name = next(iter(draft_attn_layer_names))

+
 def create_vllm_config_for_draft_model(
     target_model_vllm_config: VllmConfig,
 ) -> VllmConfig:
@@ -220,17 +222,19 @@ def create_vllm_config_for_draft_model(
     The vllm_config is useful when loading the draft model with get_model().
     """
     old = target_model_vllm_config
-    new_parallel_config = replace(old.speculative_config.draft_parallel_config,
-                                  rank=old.parallel_config.rank
+    new_parallel_config = replace(
+        old.speculative_config.draft_parallel_config, rank=old.parallel_config.rank
     )
-
-    new: VllmConfig = replace(old,
+
+    new: VllmConfig = replace(
+        old,
         quant_config=None, # quant_config is recomputed in __init__()
         model_config=old.speculative_config.draft_model_config,
         parallel_config=new_parallel_config,
     )
     return new

+
 def merge_next_token_ids_into_token_ids(
     input_token_ids: torch.Tensor,
     input_positions: torch.Tensor,
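The config cloning in the hunk above is built on `dataclasses.replace`, which copies a dataclass instance while overriding selected fields. Below is a toy sketch of that pattern; `ToyParallelConfig` and `ToyVllmConfig` are stand-ins invented for illustration and are far simpler than the real `VllmConfig` and parallel config (which also recompute derived state such as `quant_config` on construction).

from dataclasses import dataclass, replace

# Stand-in dataclasses; the real configs carry many more fields.
@dataclass
class ToyParallelConfig:
    tensor_parallel_size: int
    rank: int

@dataclass
class ToyVllmConfig:
    model_name: str
    parallel_config: ToyParallelConfig

target_cfg = ToyVllmConfig(
    model_name="target-model",
    parallel_config=ToyParallelConfig(tensor_parallel_size=1, rank=0),
)

# Simplified mirror of the pattern above: the real code takes the draft
# parallel config from speculative_config and overrides only its rank,
# then copies the whole config with the draft model swapped in.
draft_parallel = replace(target_cfg.parallel_config, rank=target_cfg.parallel_config.rank)
draft_cfg = replace(target_cfg, model_name="draft-model", parallel_config=draft_parallel)

assert draft_cfg.model_name == "draft-model"
assert target_cfg.model_name == "target-model"  # the original is untouched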
@@ -239,8 +243,8 @@ def merge_next_token_ids_into_token_ids(
     block_size: int,
     max_model_len: int,
     arange: torch.Tensor,
-    cu_num_tokens
-):
+    cu_num_tokens,
+):
     """
     Merges the next token ids with the existing token ids into a flat sequence.
     Does the same for the positions, computes new slot mapping,
@@ -251,7 +255,7 @@ def merge_next_token_ids_into_token_ids(
         seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids
     )
     logger.warning("new_token_ids: {}".format(new_token_ids))
-
+
     # append new positions
     positions_to_append = input_positions[query_end_locs] + 1
     new_positions = extend_flat_seqs(
@@ -260,16 +264,18 @@ def merge_next_token_ids_into_token_ids(
     # recompute slot mapping
     batch_size, n_blocks_per_req = cad.block_tables.shape
     req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
-
+
     query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-    req_indices = torch.repeat_interleave(req_indices, query_lens.to(cad.query_start_loc.device) + 1)
+    req_indices = torch.repeat_interleave(
+        req_indices, query_lens.to(cad.query_start_loc.device) + 1
+    )
     block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
     block_nums = cad.block_tables.view(-1)[block_table_indices]
     block_offsets = new_positions % block_size
     new_slot_mapping = block_nums * block_size + block_offsets
     # Mask out the position ids that exceed the max model length.
     exceeds_max_model_len = new_positions >= max_model_len
     new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
-
-    cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
-    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
+
+    cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
+    return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
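The slot-mapping recomputation in the last hunk maps every flat token position to a physical KV-cache slot through the per-request block table: the block index is `position // block_size`, and the slot is `block_number * block_size + position % block_size`. A self-contained sketch of the same indexing with toy values follows; it is illustrative only and omits the `+ 1` on the query lengths and the attention metadata object used in the real code.

import torch

# Toy setup: 2 requests, 4 logical blocks per request, block_size 4.
block_size = 4
block_tables = torch.tensor([[10, 11, 12, 13],
                             [20, 21, 22, 23]])   # physical block numbers
batch_size, n_blocks_per_req = block_tables.shape

# Flat positions for the two requests (3 and 2 tokens respectively).
positions = torch.tensor([4, 5, 6, 0, 1])
query_lens = torch.tensor([3, 2])
req_indices = torch.repeat_interleave(torch.arange(batch_size), query_lens)

# Same indexing scheme as the hunk above.
block_table_indices = req_indices * n_blocks_per_req + positions // block_size
block_nums = block_tables.view(-1)[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_nums * block_size + block_offsets

# Request 0's positions 4..6 land in its second block (10 -> 11);
# request 1's positions 0..1 land in its first block (20).
assert slot_mapping.tolist() == [44, 45, 46, 80, 81]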
