Commits (35)
b7d124c
[V1][CUDA Graph] Fix attention metadata tensor sizes for padded batches
ayushsatyam146 Sep 19, 2025
a42e3ce
review comments
LucasWilkinson Oct 5, 2025
f586e2b
Merge branch 'main' into cudagraph-fix
LucasWilkinson Oct 6, 2025
93071cf
format
LucasWilkinson Oct 6, 2025
f6b7d63
Merge branch 'main' into cudagraph-fix
LucasWilkinson Oct 6, 2025
9acd529
cleanup
LucasWilkinson Oct 6, 2025
d6a5a9a
cleanup
LucasWilkinson Oct 7, 2025
4a0e9df
more cleanup
LucasWilkinson Oct 7, 2025
0895405
Merge remote-tracking branch 'origin/main' into cudagraph-fix
LucasWilkinson Oct 7, 2025
3f495e2
cleanup
LucasWilkinson Oct 7, 2025
01db86e
clean up
LucasWilkinson Oct 7, 2025
5780e9a
no need for lora change
LucasWilkinson Oct 7, 2025
00c3280
review comments
LucasWilkinson Oct 8, 2025
5d838ed
Merge remote-tracking branch 'origin/main' into cudagraph-fix
LucasWilkinson Oct 8, 2025
370e700
more refactoring
LucasWilkinson Oct 8, 2025
a335b80
unifiy build attention metadata
LucasWilkinson Oct 8, 2025
89ca98e
clean-up
LucasWilkinson Oct 8, 2025
bbbc8fb
refactor
LucasWilkinson Oct 14, 2025
997f71e
wip
LucasWilkinson Oct 14, 2025
c907ef3
cleanup
LucasWilkinson Oct 14, 2025
d88842f
cleanup
LucasWilkinson Oct 14, 2025
39562aa
fix
LucasWilkinson Oct 14, 2025
0d997cf
cleanup
LucasWilkinson Oct 14, 2025
ccfb764
cleanup
LucasWilkinson Oct 14, 2025
021882c
fix
LucasWilkinson Oct 14, 2025
1d533bc
clean up
LucasWilkinson Oct 14, 2025
e48ab54
fix docs error
LucasWilkinson Oct 14, 2025
f3d09f5
Merge remote-tracking branch 'nm/lwilkinson/seperate-build-attn-metad…
LucasWilkinson Oct 15, 2025
d23b110
Fix merge conflicts: add missing imports and fix indentation
LucasWilkinson Oct 15, 2025
c3eba9b
Merge remote-tracking branch 'origin/main' into pr/ayushsatyam146/24002
LucasWilkinson Nov 11, 2025
d00b29c
wip
LucasWilkinson Nov 12, 2025
df7edde
wip
LucasWilkinson Nov 12, 2025
f95af46
Merge branch 'lwilkinson/pad-before-metadata' into pr/ayushsatyam146/…
LucasWilkinson Nov 12, 2025
6160f1c
update docs
LucasWilkinson Nov 12, 2025
cebcc54
cleanup
LucasWilkinson Nov 12, 2025
9 changes: 8 additions & 1 deletion vllm/forward_context.py
@@ -39,13 +39,20 @@ class BatchDescriptor(NamedTuple):
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
num_reqs: Optional[int] = None
"""
Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
we don't need to know the number of requests.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
return BatchDescriptor(
self.num_tokens, uniform_decode=False, num_reqs=self.num_reqs
)


def _compute_sp_num_tokens(
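A minimal standalone sketch of the updated descriptor, assuming only the fields visible in this hunk (an illustration of how num_reqs is carried through non_uniform, not the actual vLLM class):

from typing import NamedTuple, Optional


class BatchDescriptor(NamedTuple):
    # Stand-in mirroring the fields shown in the diff above.
    num_tokens: int
    uniform_decode: bool = False
    num_reqs: Optional[int] = None

    @property
    def non_uniform(self) -> "BatchDescriptor":
        # num_reqs is preserved so the dispatcher can still check
        # request-count compatibility after relaxing uniformity.
        return BatchDescriptor(
            self.num_tokens, uniform_decode=False, num_reqs=self.num_reqs
        )


desc = BatchDescriptor(num_tokens=64, uniform_decode=True, num_reqs=64)
assert desc.non_uniform == BatchDescriptor(64, False, 64)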
21 changes: 1 addition & 20 deletions vllm/v1/attention/backends/flashinfer.py
@@ -656,31 +656,12 @@ def build(

if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decodes <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)

else:
num_input_tokens = num_decode_tokens
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph
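For context, a rough standalone illustration (made-up sizes, PyTorch only, not vLLM code) of what the deleted padding branch used to do; since this PR pads the batch before attention metadata is built (see the 'lwilkinson/pad-before-metadata' merge in the commit list), the FlashInfer backend no longer needs to pad here:

import torch

num_decodes, num_input_tokens = 3, 8  # 3 real decode requests, padded to 8 slots

# paged_kv_indptr must stay non-decreasing, so padded entries repeat the
# last real cumulative page count.
paged_kv_indptr_cpu = torch.zeros(num_input_tokens + 1, dtype=torch.int32)
paged_kv_indptr_cpu[1 : 1 + num_decodes] = torch.tensor([4, 7, 9], dtype=torch.int32)
paged_kv_indptr_cpu[1 + num_decodes : 1 + num_input_tokens].fill_(9)

# FlashInfer treats a last-page length of 0 as a full page rather than an
# empty one, so padded entries are filled with 1.
paged_kv_last_page_len_cpu = torch.empty(num_input_tokens, dtype=torch.int32)
paged_kv_last_page_len_cpu[:num_decodes] = torch.tensor([16, 3, 5], dtype=torch.int32)
paged_kv_last_page_len_cpu[num_decodes:].fill_(1)

print(paged_kv_indptr_cpu.tolist())         # [0, 4, 7, 9, 9, 9, 9, 9, 9]
print(paged_kv_last_page_len_cpu.tolist())  # [16, 3, 5, 1, 1, 1, 1, 1]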
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -771,7 +771,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
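A small standalone sketch of how the padded-request handling plays out with the parenthesized mask above (illustrative values, not the actual helper):

import torch

# Three decodes, one prefill, and one padded (0-length) slot at the back.
query_lens = torch.tensor([1, 1, 1, 7, 0])
decode_threshold = 1

# Compare first, then OR the boolean masks; padded requests are grouped
# with the prefills at the back of the batch.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(is_prefill.tolist())  # [False, False, False, True, True]

# The split point is the first prefill-like request.
num_decodes = int(torch.nonzero(is_prefill)[0].item())  # 3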
110 changes: 83 additions & 27 deletions vllm/v1/cudagraph_dispatcher.py
@@ -1,9 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import NamedTuple, Optional

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger

logger = init_logger(__name__)


class CUDAGraphKey(NamedTuple):
num_tokens: int
uniform_decode: bool

@staticmethod
def from_batch_descriptor(batch_descriptor: BatchDescriptor):
return CUDAGraphKey(
batch_descriptor.num_tokens, batch_descriptor.uniform_decode
)


class CudagraphDispatcher:
@@ -31,9 +45,11 @@ def __init__(self, vllm_config: VllmConfig):
self.cudagraph_mode = self.compilation_config.cudagraph_mode

# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
CUDAGraphMode.PIECEWISE: set(),
CUDAGraphMode.FULL: set(),
self.cudagraph_keys: dict[
CUDAGraphMode, dict[CUDAGraphKey, BatchDescriptor]
] = {
CUDAGraphMode.PIECEWISE: {},
CUDAGraphMode.FULL: {},
}

not_use_piecewise_compilation = (
@@ -61,7 +77,8 @@ def add_cudagraph_key(
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
)
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
key = CUDAGraphKey.from_batch_descriptor(batch_descriptor)
self.cudagraph_keys[runtime_mode][key] = batch_descriptor

def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
@@ -71,11 +88,24 @@ def initialize_cudagraph_keys(
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
# Add mixed mode keys with proper num_reqs calculation
if (mixed_mode := cudagraph_mode.mixed_mode()) in (
CUDAGraphMode.PIECEWISE,
CUDAGraphMode.FULL,
):
for bs in self.compilation_config.cudagraph_capture_sizes:
num_reqs = (
self.calculate_num_reqs_for_tokens(
bs, uniform_decode_query_len, False
)
if mixed_mode == CUDAGraphMode.FULL
else None
)
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(num_tokens=bs, uniform_decode=False),
mixed_mode,
BatchDescriptor(
num_tokens=bs, uniform_decode=False, num_reqs=num_reqs
),
)

# if decode cudagraph mode is FULL, and we don't already have mixed
@@ -94,12 +124,39 @@
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs in cudagraph_capture_sizes_for_decode:
num_reqs = self.calculate_num_reqs_for_tokens(
bs, uniform_decode_query_len, True
)
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(num_tokens=bs, uniform_decode=True),
BatchDescriptor(
num_tokens=bs, uniform_decode=True, num_reqs=num_reqs
),
)

self.keys_initialized = True

def calculate_num_reqs_for_tokens(
Review comment (Collaborator):
I don't think I understand what this does. It gives the maximum possible number of non-empty requests? But I thought we also have requests of length 0 sometimes?

Reply (Collaborator):
This just mirrors what we capture in the dummy run (this behavior is unchanged by this PR; we just track num_reqs explicitly now). We need to refactor the dummy_run; that'll be a follow-up.
self, num_tokens: int, uniform_decode_query_len: int, uniform_decode: bool
) -> int:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs

if uniform_decode:
num_reqs = num_tokens // uniform_decode_query_len
assert num_tokens % uniform_decode_query_len == 0
assert num_reqs <= max_num_seqs
return num_reqs
return min(num_tokens, max_num_seqs)
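To make the request-count rule above concrete, a worked example under assumed config values (max_num_seqs and uniform_decode_query_len are illustrative):

max_num_seqs = 128
uniform_decode_query_len = 4  # e.g. decode with speculative-decoding drafts

# Uniform decode: every request contributes exactly uniform_decode_query_len
# tokens, so a 64-token graph is captured for 64 // 4 = 16 request slots.
num_tokens = 64
assert num_tokens % uniform_decode_query_len == 0
num_reqs_uniform = num_tokens // uniform_decode_query_len  # 16

# Mixed batch: the worst case is one token per request, capped by max_num_seqs.
num_reqs_mixed = min(num_tokens, max_num_seqs)  # 64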

def _is_compatible(
self, batch_descriptor: BatchDescriptor, candidate: BatchDescriptor
) -> bool:
"""Check if candidate cudagraph can handle the batch request."""
if candidate.num_reqs is None:
return True
assert batch_descriptor.num_reqs is not None
return candidate.num_reqs >= batch_descriptor.num_reqs

def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
@@ -113,21 +170,20 @@ def dispatch(
if not self.keys_initialized:
return CUDAGraphMode.NONE, None

non_uniform_key = batch_descriptor.non_uniform
# if a batch use cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor

# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key

# also check if non-uniform key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key

# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
key = CUDAGraphKey.from_batch_descriptor(batch_descriptor)
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_batch_desc = None

if key in self.cudagraph_keys[CUDAGraphMode.FULL]:
cudagraph_batch_desc = self.cudagraph_keys[CUDAGraphMode.FULL][key]
cudagraph_mode = CUDAGraphMode.FULL
elif key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
cudagraph_batch_desc = self.cudagraph_keys[CUDAGraphMode.PIECEWISE][key]
cudagraph_mode = CUDAGraphMode.PIECEWISE

if cudagraph_batch_desc is not None:
assert self._is_compatible(batch_descriptor, cudagraph_batch_desc), (
f"Batch descriptor {batch_descriptor} is not compatible with "
f"cudagraph batch descriptor: {cudagraph_batch_desc} (key {key})"
)
return cudagraph_mode, cudagraph_batch_desc
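A toy standalone illustration of the new two-step dispatch, with hypothetical capture sizes: look up by (num_tokens, uniform_decode), then check num_reqs compatibility against the captured descriptor.

# (num_tokens, uniform_decode) -> num_reqs the graph was captured with
captured = {
    (64, True): 16,   # uniform-decode graph with 16 request slots
    (64, False): 64,  # mixed graph with up to 64 request slots
}

def dispatch(num_tokens: int, uniform_decode: bool, num_reqs: int):
    key = (num_tokens, uniform_decode)
    if key not in captured:
        return None  # no cudagraph captured for this shape
    # A captured graph can replay any batch with the same token count and at
    # most as many requests as it was captured with.
    assert captured[key] >= num_reqs, "incompatible request count"
    return key

print(dispatch(64, True, 10))   # (64, True): fits within the 16-request capture
print(dispatch(48, False, 12))  # None: no graph captured for 48 tokens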