
Commit b045413

Varun Sundar Rabindranath (varun-sundar-rabindranath) authored and committed

[Core] CUDA Graphs for Multi-Step + Chunked-Prefill (vllm-project#8645)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Alvant <[email protected]>

1 parent e270881 · commit b045413

File tree: 3 files changed, +97 −34 lines

csrc/prepare_inputs/advance_step.cu (11 additions, 0 deletions)

@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
     long const* sampled_token_ids_ptr, long* input_positions_ptr,
     int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
     int64_t const block_tables_stride) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
+
   int num_query_blocks = div_ceil(num_queries, num_threads);

   if (blockIdx.x >= num_query_blocks) {
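
The added block clears the CUDA-graph padding lanes so a graph captured for a larger batch can be replayed over a smaller real batch: padded entries get token 0, position 0, and a -1 slot mapping so downstream kernels can recognize and skip them. Below is a host-side NumPy sketch of the same bookkeeping; the function name and buffer layout are illustrative, only the -1 pad sentinel mirrors the kernel.

    import numpy as np

    PAD_SLOT_ID = -1  # same sentinel the kernel writes for padded entries

    def pad_advance_step_buffers(input_tokens: np.ndarray,
                                 input_positions: np.ndarray,
                                 slot_mapping: np.ndarray,
                                 num_queries: int, num_seqs: int) -> None:
        """Host-side sketch of the kernel's padding loop.

        Entries [num_queries, num_seqs) are CUDA-graph padding: they carry
        no real sequence, so tokens and positions are zeroed and the slot
        mapping is set to the pad sentinel so nothing meaningful is written
        to the KV cache.
        """
        n_pad = num_seqs - num_queries
        if n_pad > 0:
            input_tokens[num_queries:num_seqs] = 0
            input_positions[num_queries:num_seqs] = 0
            slot_mapping[num_queries:num_seqs] = PAD_SLOT_ID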

vllm/attention/backends/flash_attn.py (28 additions, 20 deletions)

@@ -500,6 +500,30 @@ def _add_seq_group(
                 seq_len, context_len, start_idx,
                 self.block_size, inter_data.block_tables)

+    def _get_graph_runner_block_tables(
+            self, num_seqs: int,
+            block_tables: List[List[int]]) -> torch.Tensor:
+        # The shape of graph_block_tables is
+        # [max batch size, max context len // block size].
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
+        for i, block_table in enumerate(block_tables):
+            if block_table:
+                num_blocks = len(block_table)
+                if num_blocks <= max_blocks:
+                    graph_block_tables[i, :num_blocks] = block_table
+                else:
+                    # It may be possible to have more blocks allocated due
+                    # to lookahead slots of multi-step, however, they are
+                    # not used anyway, so can be safely ignored.
+                    graph_block_tables[
+                        i, :max_blocks] = block_table[:max_blocks]
+
+        return torch.from_numpy(graph_block_tables).to(
+            device=self.runner.device, non_blocking=True)
+
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
         """Build attention metadata with on-device tensors.
@@ -533,29 +557,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens

+        num_seqs = len(seq_lens)
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
-
-            # The shape of graph_block_tables is
-            # [max batch size, max context len // block size].
-            input_block_tables = self.runner.graph_block_tables[:batch_size]
-            max_blocks = input_block_tables.shape[1]
-            for i, block_table in enumerate(self.block_tables):
-                if block_table:
-                    num_blocks = len(block_table)
-                    if num_blocks <= max_blocks:
-                        input_block_tables[i, :num_blocks] = block_table
-                    else:
-                        # It may be possible to have more blocks allocated due
-                        # to lookahead slots of multi-step, however, they are
-                        # not used anyway, so can be safely ignored.
-                        input_block_tables[
-                            i, :max_blocks] = block_table[:max_blocks]
-
-            block_tables = torch.from_numpy(input_block_tables).to(
-                device=device, non_blocking=True)
+            num_decode_tokens = batch_size - self.num_prefill_tokens
+            block_tables = self._get_graph_runner_block_tables(
+                num_seqs, self.block_tables)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
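
Under multi-step + chunked-prefill, the first step's padded batch can carry a prefill chunk alongside the decode tokens, so the decode-token count becomes the padded batch size minus the prefill tokens instead of the whole batch. A worked example with assumed numbers:

    # Assumed first step of a multi-step + chunked-prefill schedule:
    # one prefill sequence contributes a 3-token chunk, five decode
    # sequences contribute one token each.
    num_prefill_tokens = 3
    num_seqs = 1 + 5                      # prefill + decode sequences
    cuda_graph_pad_size = 8 - num_seqs    # padded up to an assumed captured size of 8
    batch_size = 3 + 5 + cuda_graph_pad_size   # 10 padded input tokens

    # Old: num_decode_tokens = batch_size              -> 10 (counts the prefill chunk)
    # New: subtract the prefill chunk from the padded batch:
    num_decode_tokens = batch_size - num_prefill_tokens   # 7 = 5 real + 2 pad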

vllm/worker/model_runner.py (58 additions, 14 deletions)

@@ -712,14 +712,62 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

     def _use_captured_graph(self,
                             batch_size: int,
+                            decode_only: bool,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
-        return (self.decode_only and not self.runner.model_config.enforce_eager
+        return (decode_only and not self.runner.model_config.enforce_eager
                 and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)

+    def _get_cuda_graph_pad_size(self,
+                                 num_seqs: int,
+                                 max_decode_seq_len: int,
+                                 max_encoder_seq_len: int = 0) -> int:
+        """
+        Determine the number of padding sequences required for running in
+        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
+
+        In the multi-step + chunked-prefill case, only the first step
+        has Prefills (if any). The rest of the steps are guaranteed to be all
+        decodes. In this case, we set up the padding as if all the sequences
+        are decodes so we may run all steps except the first step in CUDA graph
+        mode. The padding is accounted for in the multi-step `advance_step`
+        family of functions.
+
+        Args:
+            num_seqs (int): Number of sequences scheduled to run.
+            max_decode_seq_len (int): Greatest of all the decode sequence
+                lengths. Used only in checking the viability of using
+                CUDA graphs.
+            max_encoder_seq_len (int, optional): Greatest of all the encode
+                sequence lengths. Defaults to 0. Used only in checking the
+                viability of using CUDA graphs.
+        Returns:
+            int: Returns the determined number of padding sequences. If
+                CUDA graphs is not viable, returns -1.
+        """
+        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
+            self.runner.scheduler_config.chunked_prefill_enabled
+        decode_only = self.decode_only or is_mscp
+        if not decode_only:
+            # Early exit so we can treat num_seqs as the batch_size below.
+            return -1
+
+        # batch_size out of this function refers to the number of input
+        # tokens being scheduled. This conflation of num_seqs as batch_size
+        # is valid as this is a decode-only case.
+        batch_size = num_seqs
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
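
The pad size is simply the distance from the scheduled sequence count to the next batch size for which a graph was captured, or -1 when graphs cannot be used. Below is a condensed sketch of that rounding under an assumed capture list; the real check also consults max_seq_len_to_capture and max_batchsize_to_capture, as in _use_captured_graph above.

    # Assumed capture list for illustration: 1, 2, 4, then multiples of 8.
    ASSUMED_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

    def cuda_graph_pad_size(num_seqs: int, decode_only: bool,
                            enforce_eager: bool) -> int:
        """Return the number of padding sequences, or -1 if graphs can't be used."""
        if (not decode_only or enforce_eager
                or num_seqs > ASSUMED_BATCH_SIZES_TO_CAPTURE[-1]):
            return -1
        graph_batch_size = next(b for b in ASSUMED_BATCH_SIZES_TO_CAPTURE
                                if b >= num_seqs)
        return graph_batch_size - num_seqs

    # e.g. 6 decode sequences run in the graph captured for batch size 8:
    assert cuda_graph_pad_size(6, decode_only=True, enforce_eager=False) == 2
    assert cuda_graph_pad_size(6, decode_only=False, enforce_eager=False) == -1
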
@@ -778,21 +826,17 @@ def build(self) -> ModelInputForGPU:
             for data in self.inter_data_list
         }

-        batch_size = len(input_tokens)
-        use_captured_graph = self._use_captured_graph(
-            batch_size,
-            max_decode_seq_len,
+        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_encoder_seq_len,
             max_encoder_seq_len=max_encoder_seq_len)

-        # If cuda graph can be used, pad tensors accordingly.
-        # See `capture_model` API for more details.
-        # vLLM uses cuda graph only for decoding requests.
-        cuda_graph_pad_size = -1
-        if use_captured_graph:
-            graph_batch_size = _get_graph_batch_size(batch_size)
-            assert graph_batch_size >= batch_size
-            cuda_graph_pad_size = graph_batch_size - batch_size
-            batch_size = graph_batch_size
+        batch_size = len(input_tokens)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size

         # Tokens and positions.
         if cuda_graph_pad_size:
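
Downstream of this hunk, build() grows the batch only when a pad size was actually returned, and the token/position buffers are extended to match before the attention metadata is built. A condensed sketch with illustrative values:

    # Illustrative values: 8 real input tokens, pad size 2 from the helper
    # above (-1 would mean "no CUDA graph for this batch").
    input_tokens = [101, 102, 103, 7, 8, 9, 10, 11]   # prefill chunk + decodes
    input_positions = [0, 1, 2, 14, 6, 21, 9, 33]
    cuda_graph_pad_size = 2

    batch_size = len(input_tokens)
    if cuda_graph_pad_size != -1:
        # Grow the batch to the captured graph's size; the padded entries are
        # later neutralized by advance_step (token 0, position 0, pad slot).
        batch_size += cuda_graph_pad_size
        input_tokens.extend([0] * cuda_graph_pad_size)
        input_positions.extend([0] * cuda_graph_pad_size)

    assert batch_size == 10 and len(input_tokens) == batch_size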
