
Commit 4d0a890

leex404 (Xin Li) authored and committed

fix: private memory size too large in sample_recovered_tokens_kernel (#115)

* [fix] fix sample_recovered_tokens_kernel using too much private memory
* [fix] fix type error in bf16_paged_mqa_logits
* [chore] change file directory

Signed-off-by: Xin Li <[email protected]>
Co-authored-by: Xin Li <[email protected]>
1 parent 62f9761 commit 4d0a890
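
Context for the first fix: the private-memory blow-up comes from materializing the whole padded vocabulary in a single program. Below is a minimal, hedged sketch of that pattern (not the actual upstream vLLM kernel; names and shapes are assumptions). The patched kernel in this commit replaces the one full-vocab tl.arange with a BLOCK_SIZE-wide loop.

# Hedged sketch only: illustrates the memory pattern this commit avoids,
# not the actual upstream vLLM kernel.
from vllm.triton_utils import tl, triton


@triton.jit
def full_vocab_argmax_kernel(
    out_ptr,       # [num_tokens]
    probs_ptr,     # [num_tokens, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,  # e.g. next power of 2 of vocab_size
):
    tok = tl.program_id(0)
    # A single tl.arange over the whole padded vocab keeps PADDED_VOCAB_SIZE
    # elements live per program -- this is the private-memory pressure that
    # the patched kernel removes by tiling the vocabulary with BLOCK_SIZE.
    offs = tl.arange(0, PADDED_VOCAB_SIZE)
    p = tl.load(probs_ptr + tok * vocab_size + offs,
                mask=offs < vocab_size,
                other=-float("inf"))
    tl.store(out_ptr + tok, tl.argmax(p, axis=-1))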

File tree

4 files changed (+91 / -1 lines)


vllm_metax/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -632,7 +632,7 @@ def sparse_attn_indexer(
             decode_metadata.seq_lens,
             decode_metadata.block_table,
             decode_metadata.schedule_metadata,
-            max_context_len=max_model_len,
+            max_model_len,
         )
         # padded query len
         current_device = padded_q_bf16_decode_tokens.device
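
The hunk above drops the max_context_len= keyword and passes max_model_len positionally, matching the "[fix] fix type error in bf16_paged_mqa_logits" item in the commit message: presumably the callee has no parameter named max_context_len, so the keyword call raised a TypeError. A tiny illustration with a hypothetical stand-in signature (the real bf16_paged_mqa_logits parameters are not shown in this diff):

# Hypothetical stand-in for the callee; only the last parameter's name matters
# here, and it is an assumption.
def bf16_paged_mqa_logits_stub(q, kv_cache, seq_lens, block_table,
                               schedule_metadata, max_model_len):
    return max_model_len


args = (None, None, None, None, None)
bf16_paged_mqa_logits_stub(*args, 4096)                       # new call: positional, OK
try:
    bf16_paged_mqa_logits_stub(*args, max_context_len=4096)   # old call: unknown keyword
except TypeError as exc:
    print("old call style fails:", exc)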

vllm_metax/patch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@
 from . import device_allocator
 from . import model_executor
 from . import oot
+from . import sample

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from . import rejection_sampler

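Together, the two __init__ additions wire the new module into the patch chain: importing the patch package pulls in the sample subpackage, which imports the rejection-sampler module, whose module-level assignment (last line of the next file) overrides the upstream kernel. A hedged usage sketch, assuming the new files live at vllm_metax/patch/sample/__init__.py and vllm_metax/patch/sample/rejection_sampler.py (the paths are not shown in this extract):

# Hedged sketch: the side-effect import is assumed to be all that is needed to
# activate the patched kernel; the vllm_metax module paths below are assumptions.
import vllm.v1.sample.rejection_sampler as upstream

import vllm_metax.patch  # noqa: F401  (import-time monkey patch)
from vllm_metax.patch.sample import rejection_sampler as patched

# After the import, the upstream module-level name points at the blocked kernel.
assert upstream.sample_recovered_tokens_kernel is patched.sample_recovered_tokens_kernel
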
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.triton_utils import tl, triton
+
+import vllm.v1.sample.rejection_sampler
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+@triton.jit
+def sample_recovered_tokens_kernel(
+    output_token_ids_ptr,  # [num_tokens]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    q_ptr,  # [batch_size, vocab_size]
+    vocab_size,
+    PADDED_VOCAB_SIZE: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr = 1024,
+):
+    req_idx = tl.program_id(0)
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    # Early exit for out-of-range positions.
+    pos = tl.program_id(1)
+    if pos >= num_draft_tokens:
+        return
+
+    max_prob = -float('inf')
+    best_token_id = 0
+
+    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
+        block_end = min(block_start + BLOCK_SIZE, vocab_size)
+
+        vocab_offset = tl.arange(0, BLOCK_SIZE)
+        mask = vocab_offset < block_end - block_start
+
+        if NO_DRAFT_PROBS:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size +
+                block_start + vocab_offset,
+                mask=(mask & (vocab_offset + block_start != draft_token_id)),
+                other=0)
+
+        else:
+            draft_prob = tl.load(draft_probs_ptr +
+                                 (start_idx + pos) * vocab_size + block_start +
+                                 vocab_offset,
+                                 mask=mask,
+                                 other=0)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  block_start + vocab_offset,
+                                  mask=mask,
+                                  other=0)
+            prob = tl.maximum(target_prob - draft_prob, 0)
+
+        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
+        # `tl.argmax` will select the maximum value.
+
+        q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
+                    mask=mask,
+                    other=float("-inf"))
+
+        # recovered_id = tl.argmax(prob / q, axis=-1)
+        # calc block prob and token ID
+        block_prob = prob / q
+        block_max_prob = tl.max(block_prob, axis=-1)
+        block_best_token_id = tl.argmax(block_prob, axis=-1) + block_start
+
+        # update token ID
+        max_prob = tl.maximum(max_prob, block_max_prob)
+        best_token_id = tl.where(block_max_prob >= max_prob,
+                                 block_best_token_id, best_token_id)
+
+    tl.store(output_token_ids_ptr + start_idx + pos, best_token_id)
+
+
+vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = sample_recovered_tokens_kernel
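
Design note on the kernel above: instead of one PADDED_VOCAB_SIZE-wide load per program, the loop keeps at most BLOCK_SIZE (default 1024) probabilities live at a time and folds each tile into a running (max_prob, best_token_id) pair, so private-memory use is bounded by BLOCK_SIZE rather than by the padded vocabulary size. A hedged launch sketch follows; the grid shape, tensor names, and dtypes are assumptions based on the kernel's pointer comments, not taken from the upstream caller.

# Hedged launch sketch; tensor names, dtypes, grid shape, and the module path
# are assumptions, not copied from the upstream vLLM call site.
import torch
from vllm.triton_utils import triton

from vllm_metax.patch.sample.rejection_sampler import (  # assumed module path
    sample_recovered_tokens_kernel)

batch_size, max_spec_len, vocab_size = 4, 2, 32000
num_tokens = batch_size * max_spec_len

output_token_ids = torch.empty(num_tokens, dtype=torch.int64, device="cuda")
# Cumulative draft-token counts: each request has max_spec_len draft tokens.
cu_num_draft_tokens = torch.arange(
    max_spec_len, max_spec_len * (batch_size + 1), max_spec_len,
    dtype=torch.int32, device="cuda")
draft_token_ids = torch.randint(vocab_size, (num_tokens,), device="cuda")
target_probs = torch.rand(num_tokens, vocab_size, device="cuda")
q = torch.rand(batch_size, vocab_size, device="cuda")

# One program per (request, draft-token position); each program scans the
# vocabulary in BLOCK_SIZE tiles instead of one PADDED_VOCAB_SIZE-wide load.
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    None,                 # draft_probs_ptr is unused when NO_DRAFT_PROBS
    target_probs,
    q,
    vocab_size,
    PADDED_VOCAB_SIZE=triton.next_power_of_2(vocab_size),
    NO_DRAFT_PROBS=True,
)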
