Merged

Commits (37 total; the diff below shows changes from 3 commits)
a012257
renaming to piecewise_backend.py
fhl2000 Aug 16, 2025
c37583c
dispatch cascade attention to NONE or PIECEWISE runtime mode;clean up…
fhl2000 Aug 17, 2025
5e998c2
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Aug 17, 2025
648fbb3
apply suggestion from bot
fhl2000 Aug 17, 2025
3c88284
Merge branch 'main' into post_issues_20059
fhl2000 Aug 17, 2025
868c85f
fix bug when attn_metadata have no use_cascade
fhl2000 Aug 17, 2025
3bef9e4
simple dispatching test
fhl2000 Aug 17, 2025
05cf012
minor comment tweak
fhl2000 Aug 17, 2025
20d8afb
address comments part1
fhl2000 Aug 22, 2025
9648a84
Merge branch 'main' into post_issues_20059
fhl2000 Aug 23, 2025
da79494
fix comments part2
fhl2000 Aug 23, 2025
01083c3
fix double validation of deprecated flag
fhl2000 Aug 27, 2025
ca3c2c9
Merge branch 'main' into post_issues_20059
fhl2000 Aug 28, 2025
59db6b1
Merge branch 'main' into post_issues_20059
fhl2000 Aug 29, 2025
11421de
Merge branch 'main' into post_issues_20059
fhl2000 Sep 4, 2025
2bf5569
fix pre-commit
fhl2000 Sep 4, 2025
bd1762a
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 9, 2025
2a50ecc
Merge branch 'main' into post_issues_20059
fhl2000 Sep 10, 2025
4aef453
Merge branch 'main' into post_issues_20059
fhl2000 Sep 14, 2025
ff3c671
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 19, 2025
27eecc2
disable cascade_attn when DBO
fhl2000 Sep 19, 2025
8d3ecc8
Merge branch 'main' into post_issues_20059
fhl2000 Sep 20, 2025
48a8c7f
pre-commit
fhl2000 Sep 20, 2025
f3e08f3
remove piecewise cudagraph wrapper when no needed
fhl2000 Sep 23, 2025
3faff97
simplify set_splitting_ops_for_v1
fhl2000 Sep 23, 2025
f09e47f
resolve merged from main
fhl2000 Sep 23, 2025
b8894f2
Merge branch 'main' into post_issues_20059
fhl2000 Sep 24, 2025
92cbd4f
pre-commit
fhl2000 Sep 24, 2025
df90576
modify comments for full_cuda_graph
fhl2000 Sep 24, 2025
6176761
Merge branch 'main' into post_issues_20059
ProExpertProg Sep 24, 2025
4679802
address comment;add test for splitting_ops
fhl2000 Sep 24, 2025
5475e9e
fix profile_run log
fhl2000 Sep 25, 2025
413079b
temporary disable cascade attention
fhl2000 Sep 25, 2025
254bfd3
Merge branch 'main' into post_issues_20059
fhl2000 Sep 25, 2025
f584663
Merge branch 'main' into post_issues_20059
fhl2000 Sep 25, 2025
d8a1ad7
recover
fhl2000 Sep 25, 2025
891723a
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 26, 2025
2 changes: 1 addition & 1 deletion vllm/compilation/backends.py
@@ -338,7 +338,7 @@ def call_module(self, target: torch.fx.node.Target,
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend
from .piecewise_backend import PiecewiseBackend

piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
18 changes: 13 additions & 5 deletions vllm/config/__init__.py
@@ -3626,13 +3626,21 @@ def __post_init__(self):

# final check of cudagraph mode after platform-specific update
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
and self.model_config is not None and \
not self.model_config.disable_cascade_attn:
logger.info("CUDAGraphMode.FULL is not supported with "
"cascade attention currently. Disabling cascade"
"attention.")
self.model_config.disable_cascade_attn = True
warn_msg = ("Cascade attention is not supported with full "
"cudagraphs currently. ")
if self.compilation_config.cudagraph_mode.\
has_piecewise_cudagraphs():
logger.warning_once(
warn_msg + "It will dispatched to "
"piecewise cudagraphs if a batch runs into cascade "
"attentions")
else:
logger.warning_once(
warn_msg + "It will fallback to eager execution if a "
"batch runs into cascade attentions")

if self.compilation_config.cudagraph_mode\
.requires_piecewise_compilation():
33 changes: 29 additions & 4 deletions vllm/config/compilation.py
@@ -62,9 +62,17 @@ def max_cudagraph_mode(self) -> 'CUDAGraphMode':
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL

def has_piecewise_cudagraphs(self) -> bool:
return self.requires_piecewise_compilation()
Comment on lines +64 to +65
Collaborator:
These two seem semantically different

Contributor Author:
Yeah, but they are equivalent in actuality since we don't allow piecewise mode with empty splitting ops (translated to FULL in this case). So, having piecewise_cudagraph means requiring piecewise compilation, and requiring piecewise compilation implies having piecewise_cudagraph.
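
To make this thread easier to follow, here is a minimal, self-contained sketch showing which modes the shared predicate ends up true for, given the invariant described above (PIECEWISE with empty splitting_ops is promoted to FULL before use). The member values and helper names are assumptions for illustration, not copied from vLLM:

```python
# Simplified stand-in for CUDAGraphMode; values and helpers are assumptions.
import enum


class CUDAGraphModeSketch(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (2, 0)    # full graphs for decode, nothing for mixed batches
    FULL_AND_PIECEWISE = (2, 1)  # full graphs for decode, piecewise for mixed batches

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def mixed_mode(self) -> "CUDAGraphModeSketch":
        return (CUDAGraphModeSketch(self.value[1])
                if self.separate_routine() else self)

    def requires_piecewise_compilation(self) -> bool:
        # True whenever the mode keeps a PIECEWISE component. Because
        # PIECEWISE + empty splitting_ops is rewritten to FULL before use,
        # this is also exactly "has piecewise cudagraphs".
        return (self == CUDAGraphModeSketch.PIECEWISE
                or self.mixed_mode() == CUDAGraphModeSketch.PIECEWISE)

    def has_piecewise_cudagraphs(self) -> bool:
        # The diff above simply aliases this to requires_piecewise_compilation().
        return self.requires_piecewise_compilation()


for mode in CUDAGraphModeSketch:
    print(mode.name, mode.has_piecewise_cudagraphs())
# NONE False, PIECEWISE True, FULL False,
# FULL_DECODE_ONLY False, FULL_AND_PIECEWISE True
```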


def separate_routine(self) -> bool:
return isinstance(self.value, tuple)

def vaild_runtime_modes(self) -> bool:
return self in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
]


@config
@dataclass
@@ -544,20 +552,37 @@ def set_splitting_ops_for_v1(self):
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
self.splitting_ops = self._attention_ops
if self.pass_config.enable_attn_fusion:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"When enable_attn_fusion, splitting_ops will be set "
"to empty list, and cudagraph_mode containing "
"PIECEWISE will be treated as FULL cudagraph_mode. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
Collaborator:
This is confusing, we should clarify that we disable piecewise not that piecewise is handled as full. Also I think we should do full_and_piecewise->full and piecewise->none, not piecewise->full.

Contributor Author:
We have two places in this function doing the piecewise->full mapping. The first case is for attn_ops fusion (splitting_ops=[]), so it must be FULL mode there. The second is when users explicitly set splitting_ops=[]. I agree that for this case it is more reasonable to do full_and_piecewise->full and piecewise->none

Collaborator:
Yeah for the attn fusion case let's just explicitly update cg mode there.

"explicitly if encountering any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
else:
self.splitting_ops = self._attention_ops
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
"treated as FULL cudagraph_mode. Please ensure you are "
"splitting_ops, cudagraph_mode containing PIECEWISE will "
"be treated as FULL cudagraph_mode. Please ensure you are "
"using attention backends that support cudagraph or set "
"cudagraph_mode to NONE explicitly if encountering "
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
else: # len(self.splitting_ops) > 0:
assert not self.pass_config.enable_attn_fusion or \
not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "
"when enable_attn_fusion is True")
Collaborator:
This logic has become insane. We should think about how to simplify it but first, we should:

  1. Add a test that checks all supported scenarios (as well as log messages and errors)
  2. Extract this into a function

Contributor Author:
Have simplified a bit.
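
Since the resolution logic is the sticking point in this thread, the hypothetical helper below sketches the behavior the reviewers converge on here: set FULL explicitly for the attn-fusion case, and downgrade FULL_AND_PIECEWISE to FULL / PIECEWISE to NONE when the user forces empty splitting_ops. The class, method, and string modes are made up for illustration; this is not the merged code.

```python
# Hypothetical resolution helper; string modes and names are illustrative only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CompilationCfgSketch:
    cudagraph_mode: str = "FULL_AND_PIECEWISE"
    splitting_ops: Optional[list] = None
    enable_attn_fusion: bool = False

    # Placeholder list of attention splitting ops.
    ATTENTION_OPS = ("vllm.unified_attention", "vllm.unified_attention_with_output")

    def resolve_splitting_ops(self) -> None:
        if self.splitting_ops is None:
            if self.enable_attn_fusion:
                # Attention is fused away, so there is nothing to split on:
                # disable piecewise explicitly and run FULL cudagraphs.
                self.splitting_ops = []
                self.cudagraph_mode = "FULL"
            else:
                self.splitting_ops = list(self.ATTENTION_OPS)
        elif not self.splitting_ops:
            # User explicitly asked for no splitting: drop the piecewise half
            # rather than silently treating PIECEWISE as FULL.
            if self.cudagraph_mode == "FULL_AND_PIECEWISE":
                self.cudagraph_mode = "FULL"
            elif self.cudagraph_mode == "PIECEWISE":
                self.cudagraph_mode = "NONE"


cfg = CompilationCfgSketch(cudagraph_mode="PIECEWISE", splitting_ops=[])
cfg.resolve_splitting_ops()
print(cfg.cudagraph_mode)  # NONE
```

Note that the merged diff above still maps the empty-splitting-ops cases to FULL; the sketch only illustrates the alternative discussed in this thread.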


def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
3 changes: 1 addition & 2 deletions vllm/forward_context.py
@@ -179,8 +179,7 @@ class ForwardContext:
batch_descriptor: Optional[BatchDescriptor] = None

def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
assert self.cudagraph_runtime_mode.vaild_runtime_modes(), \
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"


44 changes: 23 additions & 21 deletions vllm/v1/cudagraph_dispatcher.py
@@ -11,7 +11,8 @@

class CudagraphDispatcher:
"""
Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs.
Runtime cudagraph dispatcher to dispatch keys for multiple sets of
cudagraphs.

The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
for FULL cudagraph runtime mode. The keys are initialized depending on
@@ -21,10 +22,10 @@ class CudagraphDispatcher:

At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
based on the input key. After dispatching (commuicate via forward context),
the cudagraph wrappers will trust the dispatch key to do either capturing
or replaying (if mode matched), or pass through to the underlying runnable
without cudagraph (if mode no match or mode is NONE).
based on the input key. After dispatching (communicated via forward
context), the cudagraph wrappers will trust the dispatch key to either
capture or replay (if the mode matches), or pass through to the underlying
runnable without cudagraph (if the mode does not match or mode is NONE).
"""

def __init__(self, vllm_config: VllmConfig):
@@ -52,19 +53,15 @@ def __init__(self, vllm_config: VllmConfig):
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
batch_descriptor: BatchDescriptor):
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {runtime_mode}"
f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
self.cudagraph_keys[runtime_mode].add(batch_descriptor)

def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int):
# This should be called only after attention backend is initialized.

# Note: we create all valid keys possible for cudagraph but do not
# guarantee all keys would be used. For example, we create keys for
# piecewise cudagraphs when it is piecewise compilation, which is always
# valid, but for attention backend support unified routine, we may not
# trigger capturing/replaying the piecewise cudagraphs depending on
# CompilationConfig.cudagraph_mode. In addition, if we allow lazy
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
@@ -89,10 +86,13 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
self.keys_initialized = True

def dispatch(
self, batch_descriptor: BatchDescriptor
self,
batch_descriptor: BatchDescriptor,
use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
"""
Given a batch descriptor, dispatch to a cudagraph mode.
Given conditions(e.g.,batch descriptor and if using cascade attention),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
@@ -102,14 +102,16 @@ def dispatch(
"initialized. No cudagraph will be used.")
return CUDAGraphMode.NONE, None

# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor

# otherwise, check if non-uniform key exists
non_uniform_key = batch_descriptor.non_uniform
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# if a batch use cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor

# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key

# also check if non-uniform key exists for more "general"
# piecewise cudagraph
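For readers skimming the diff, here is a hedged sketch of the dispatch priority the hunk above implements once cascade attention is taken into account. The helper and key types are simplified stand-ins, not the CudagraphDispatcher API:

```python
# Simplified stand-in for CudagraphDispatcher.dispatch(); the real code keys on
# BatchDescriptor objects and returns (CUDAGraphMode, Optional[BatchDescriptor]).
from typing import Optional

NONE, PIECEWISE, FULL = "NONE", "PIECEWISE", "FULL"


def dispatch_sketch(
    keys: dict,                      # mode -> set of captured (num_tokens, uniform) keys
    descriptor: tuple,               # e.g. (8, True) for a uniform-decode batch of 8 tokens
    use_cascade_attn: bool = False,
) -> tuple[str, Optional[tuple]]:
    num_tokens, _ = descriptor
    non_uniform = (num_tokens, False)
    # Cascade attention is not supported inside full cudagraphs, so skip them.
    if not use_cascade_attn:
        if descriptor in keys[FULL]:
            return FULL, descriptor
        if non_uniform in keys[FULL]:
            return FULL, non_uniform
    # Otherwise try the more general piecewise graph for this token count.
    if non_uniform in keys[PIECEWISE]:
        return PIECEWISE, non_uniform
    # Nothing captured for this shape (or keys not initialized): run eagerly.
    return NONE, None


keys = {FULL: {(8, True)}, PIECEWISE: {(8, False)}}
print(dispatch_sketch(keys, (8, True)))                         # ('FULL', (8, True))
print(dispatch_sketch(keys, (8, True), use_cascade_attn=True))  # ('PIECEWISE', (8, False))
```

This mirrors the warning added in vllm/config/__init__.py: with a piecewise-capable mode, cascade-attention batches fall through to piecewise cudagraphs; otherwise they fall back to eager execution.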
33 changes: 22 additions & 11 deletions vllm/v1/worker/gpu_model_runner.py
@@ -689,11 +689,13 @@ def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
np.ndarray, Optional[CommonAttentionMetadata], int, bool]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata
logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens, use_cascade_attn
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -840,6 +842,7 @@ def _prepare_inputs(
)

attn_metadata: dict[str, Any] = {}
use_cascade_attn = False

# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
@@ -908,6 +911,8 @@ def _prepare_inputs(
common_attn_metadata=common_attn_metadata,
))

use_cascade_attn |= attn_metadata_i.use_cascade

fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
@@ -938,7 +943,7 @@ def _prepare_inputs(

return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
max_num_scheduled_tokens, use_cascade_attn)

def _compute_cascade_attn_prefix_len(
self,
@@ -1517,7 +1522,8 @@ def execute_model(
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
max_query_len,
use_cascade_attn) = (self._prepare_inputs(scheduler_output))

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1593,7 +1599,8 @@ def execute_model(
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
self.cudagraph_dispatcher.dispatch(batch_descriptor,
use_cascade_attn)

# Run the model.
# Use persistent buffers for CUDA graphs.
@@ -2253,9 +2260,7 @@ def _dummy_run(
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
"""
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
assert cudagraph_runtime_mode.vaild_runtime_modes()

# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
@@ -2709,9 +2714,9 @@ def freeze_gc():
def _capture_cudagraphs(self, compilation_cases: list[int],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
cudagraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
assert cudagraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE],\
f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

# Only rank 0 should print progress bar during capture
if is_global_first_rank():
@@ -2853,6 +2858,12 @@ def create_attn_groups(
self.is_encoder_only_model = True

def initialize_cudagraph_capture(self) -> None:
"""
Resolve the cudagraph_mode when there are multiple
attention backends with conflicting CUDA graph support.
Initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
