
Commit 869d762

comaniac authored and LeiWang1999 committed

[Bugfix] Fix decode tokens w. CUDA graph (vllm-project#6757)
Signed-off-by: LeiWang1999 <[email protected]>

1 parent a411329, commit 869d762

File tree: 4 files changed, +31 -4 lines changed

tests/worker/test_model_runner.py
1 addition, 0 deletions

@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     for _ in range(expected_bs - len(seq_lens)):
         seq_lens.append(1)
     assert attn_metadata.seq_lens == seq_lens
+    assert attn_metadata.num_decode_tokens == len(seq_lens)
     start_idx = 0
     start_loc = [start_idx]
     for _ in context_lens:
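
A minimal sketch of what the new assertion pins down (the values below are illustrative, not the test's actual fixtures): once seq_lens has been padded up to the CUDA-graph batch size, every sequence, real or padded, contributes exactly one decode token.

# Hedged sketch of the test's expectation; expected_bs and the real
# sequence lengths here are made up for illustration.
expected_bs = 8
seq_lens = [5, 7, 9]                             # real decode sequences
seq_lens += [1] * (expected_bs - len(seq_lens))  # pad with length-1 dummies
# One decode token per (real or padded) sequence:
num_decode_tokens = len(seq_lens)
assert num_decode_tokens == expected_bs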

vllm/attention/backends/flash_attn.py
10 additions, 2 deletions

@@ -272,7 +272,15 @@ def _add_seq_group(

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors."""
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -297,7 +305,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
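
The bug this (and the two identical backend changes below) fixes: per the new docstring, batch_size arrives already padded to the captured CUDA-graph size, so adding cuda_graph_pad_size on top counted the padding twice. A hedged sketch with made-up numbers:

# Illustrative numbers only; not taken from vLLM internals.
num_seqs = 3                   # real decode sequences in the batch
graph_batch_size = 8           # nearest captured CUDA graph size
cuda_graph_pad_size = graph_batch_size - num_seqs  # 5 padding entries
batch_size = graph_batch_size  # the "maybe padded" batch size: already 8

buggy = batch_size + cuda_graph_pad_size  # 13: padding counted twice
fixed = batch_size                        # 8: one decode token per sequence
assert fixed == num_seqs + cuda_graph_pad_size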

vllm/attention/backends/flashinfer.py
10 additions, 1 deletion

@@ -320,6 +320,15 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -334,7 +343,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].

vllm/attention/backends/utils.py
10 additions, 1 deletion

@@ -149,6 +149,15 @@ def _add_seq_group(

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -173,7 +182,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
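
For reference, a hypothetical call shape for build under the documented contract (the builder object and the concrete values are assumptions; only the parameter names and their meanings come from the docstring above):

# Hypothetical invocation; 'builder' and all values are illustrative.
metadata = builder.build(
    seq_lens=[5, 7, 9, 1, 1, 1, 1, 1],  # padded to the captured batch size
    query_lens=[1] * 8,                 # decode phase: one query token each
    cuda_graph_pad_size=5,              # -1 when CUDA graphs are not used
    batch_size=8,                       # already includes the padding
)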
