Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
92b1733
FA2 and FlashInfer Full cuda graph support
fhl2000 Jun 25, 2025
58ce477
fix the arch support in CMakeLists.txt to include 8.9
fhl2000 Jun 25, 2025
c2c5fea
Refactors
fhl2000 Jun 25, 2025
1606880
refactors
fhl2000 Jun 25, 2025
806432a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 25, 2025
7c5df45
refactor
fhl2000 Jun 25, 2025
c7a9424
Add check for separate_attention_routine flag
fhl2000 Jun 25, 2025
e8b9296
fix typo error
fhl2000 Jun 26, 2025
94d0b79
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 27, 2025
a67c698
refactors and rearchitect cuda graph logic
fhl2000 Jun 28, 2025
da110af
Refactors
fhl2000 Jun 28, 2025
deaf0fe
Delect one commit
fhl2000 Jun 28, 2025
02ca154
Add support for force_no_split_graph
fhl2000 Jun 28, 2025
fa0d25c
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 1, 2025
5108bef
Huge refactors to separete cudagraph logic from vllm compilation
fhl2000 Jul 5, 2025
1c1873d
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 5, 2025
7d4667a
refactors
fhl2000 Jul 5, 2025
fedff47
fix errors
fhl2000 Jul 5, 2025
833ac56
fix small error by lazy import
fhl2000 Jul 5, 2025
d57257d
handle lint-and-deploy errors for cpu execution
fhl2000 Jul 5, 2025
8b7ea7a
remove redundents
fhl2000 Jul 5, 2025
328615d
Clear
fhl2000 Jul 6, 2025
debc682
Big refactors
fhl2000 Jul 9, 2025
cad6c39
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 9, 2025
dc455ee
cleanup
fhl2000 Jul 10, 2025
620a728
fix warmup
fhl2000 Jul 10, 2025
b1e6978
Commit suggestion: Update vllm/config.py
fhl2000 Jul 10, 2025
beee69a
commit suggestion2: Update vllm/config.py
fhl2000 Jul 10, 2025
21b1a8d
fix enforce_eager
fhl2000 Jul 10, 2025
ec79af7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 10, 2025
210359a
small cleanup for pre-commit
fhl2000 Jul 10, 2025
11263e0
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 11, 2025
9a38a4e
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 12, 2025
699aff3
refactors
fhl2000 Jul 13, 2025
ef3d9d9
resolve yapf conflicts with isort
fhl2000 Jul 13, 2025
658565e
fixes
fhl2000 Jul 13, 2025
15e2b4a
fix global graph pool issue
fhl2000 Jul 13, 2025
4253dbf
fix refactors
fhl2000 Jul 13, 2025
2783e26
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 14, 2025
1b54962
more refactors
fhl2000 Jul 14, 2025
fb2a3c7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 17, 2025
d6269bd
refactors for and more
fhl2000 Jul 17, 2025
2e1304c
fix pre-commit
fhl2000 Jul 17, 2025
db22ca5
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 18, 2025
72d40e6
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 20, 2025
0c79e53
change cudagraph dispatching logics; runtime style->runtime mode
fhl2000 Jul 21, 2025
75db3a6
pass pre-commit
fhl2000 Jul 21, 2025
0bca4c4
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 23, 2025
9d2f148
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 24, 2025
60bdc61
fix bug when cudagraph_separate_routine==False
fhl2000 Jul 24, 2025
9036bd2
recover FlashInfer from main branch
fhl2000 Jul 24, 2025
89ec3aa
address comments and clean up
fhl2000 Jul 26, 2025
4b991a3
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 26, 2025
614f6ea
clean up
fhl2000 Jul 26, 2025
c049627
fix
fhl2000 Jul 26, 2025
e69e488
add tests; more docs
fhl2000 Jul 27, 2025
835086a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 27, 2025
534410e
clean up
fhl2000 Jul 27, 2025
618f7c0
small fix
fhl2000 Jul 27, 2025
1b343eb
add more docs
fhl2000 Jul 27, 2025
532f245
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 28, 2025
431a726
simplify the logic
fhl2000 Jul 28, 2025
19faeda
fix CI failures
fhl2000 Jul 29, 2025
348a117
fix CI failures again
fhl2000 Jul 29, 2025
fc5e37a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 29, 2025
4d9829f
fix pre-commit
fhl2000 Jul 29, 2025
7773608
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 29, 2025
a692bb6
fix CI
fhl2000 Jul 29, 2025
543f264
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 30, 2025
3e5959a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 31, 2025
aa35551
fix errors;move default initialization of cudagraph_mode to __post_in…
fhl2000 Jul 31, 2025
bad2710
fix a potential bug
fhl2000 Jul 31, 2025
f175c16
Merge branch 'vllm-project:main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 31, 2025
9916a75
Merge remote-tracking branch 'origin/main' into pr-20059
LucasWilkinson Jul 31, 2025
81d7561
wip rework cudagraph_mode
LucasWilkinson Aug 1, 2025
0137d84
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 2, 2025
1bfb855
fix and re-enable FlashInfer full cudagraph
fhl2000 Aug 2, 2025
24c40ab
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 2, 2025
95d94f8
fix some CI tests
fhl2000 Aug 2, 2025
e7763ef
fallback
LucasWilkinson Aug 4, 2025
803a185
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 4, 2025
645accf
warn perferred
LucasWilkinson Aug 4, 2025
5029a6a
fix bugs and some refactors;temporarily add FULL_DOUBLE mode
fhl2000 Aug 5, 2025
fef7eee
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 5, 2025
e796196
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 5, 2025
816024e
fix incorrectly infering type from CUDAGraphWrapper
fhl2000 Aug 5, 2025
651f729
fix and refactor cudagraph_mode checkings
fhl2000 Aug 6, 2025
38ddeaf
remove full double
LucasWilkinson Aug 6, 2025
9ca04ed
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 7, 2025
a7adfae
Merge remote-tracking branch 'origin/main' into fhl2000/full_cudagrap…
LucasWilkinson Aug 7, 2025
14e83f5
Merge branch 'fhl2000_full_cudagraph_FA2_FlashInfer_merge' into full_…
LucasWilkinson Aug 7, 2025
9cc6b93
fix
LucasWilkinson Aug 7, 2025
1e97920
fix
LucasWilkinson Aug 7, 2025
25b6242
cleanup
LucasWilkinson Aug 7, 2025
85f20bf
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 8, 2025
766eb7c
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 8, 2025
b0374be
deprecation
LucasWilkinson Aug 8, 2025
a160dd4
migrate flags
LucasWilkinson Aug 8, 2025
c2dc791
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 8, 2025
028f119
cleanup
LucasWilkinson Aug 8, 2025
43db16d
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 9, 2025
2cad036
fix some unit tests
LucasWilkinson Aug 9, 2025
6839e88
more cleanup
LucasWilkinson Aug 9, 2025
04ed99a
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 9, 2025
3f2b279
fix more unit tests
LucasWilkinson Aug 9, 2025
d500150
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
fhl2000 Aug 10, 2025
a56d549
fix is_attention_splitting;fix new mamba_attn cg support
fhl2000 Aug 10, 2025
3499d7b
wip
LucasWilkinson Aug 10, 2025
83d4e7c
stabalize unit test
LucasWilkinson Aug 11, 2025
bf8a51d
cleanup
LucasWilkinson Aug 11, 2025
d1f62e4
unit test fix
LucasWilkinson Aug 11, 2025
7e19ca4
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 11, 2025
c722f2c
refactor
LucasWilkinson Aug 11, 2025
1937615
remove accidentally committed file
LucasWilkinson Aug 11, 2025
ce9cc82
fix XPU tests
LucasWilkinson Aug 11, 2025
f3561f9
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 11, 2025
9d6b189
fix xpu
LucasWilkinson Aug 11, 2025
3c4b532
match HPU cudagraph handling + down grade log
LucasWilkinson Aug 11, 2025
bed9576
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 12, 2025
19f7447
fix xpu
LucasWilkinson Aug 12, 2025
0122313
unit test fixes
LucasWilkinson Aug 12, 2025
3a2041b
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
641b10b
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
974c707
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
3805f3f
Apply suggestions from code review
LucasWilkinson Aug 15, 2025
f2c437a
review comments
LucasWilkinson Aug 15, 2025
1ff41d8
Update vllm/v1/worker/gpu_model_runner.py
LucasWilkinson Aug 15, 2025
af2a38c
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
fhl2000 Aug 15, 2025
f751e50
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
Expand Down Expand Up @@ -684,7 +684,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)

#
Expand Down
16 changes: 12 additions & 4 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

self._called = True

if not self.compilation_config.use_cudagraph or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
Expand All @@ -585,6 +581,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
any(is_symbolic(d) for d in x.size())
]

if self.compilation_config.full_cuda_graph:
assert self.compilation_config.use_cudagraph, \
"full_cuda_graph mode requires use_cudagraph to be True"
fullgraph_wrapper = resolve_obj_by_qualname(
current_platform.get_fullgraph_wrapper_cls())
self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
self.graph_pool, self.sym_tensor_indices)

if not self.compilation_config.use_cudagraph or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

# compiler managed cudagraph input buffers
# we assume the first run with symbolic shapes
# has the maximum size among all the tensors
Expand Down
43 changes: 43 additions & 0 deletions vllm/compilation/base_piecewise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,46 @@ def __call__(self, *args) -> Any:
or a replayed static graph.
"""
raise NotImplementedError


class AbstractFullgraphWrapper(Protocol):
"""
FullgraphWrapper interface that allows platforms to wrap the piecewise graph
to be viewed or captured as a full graph.
"""

def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, sym_shape_indices: list[int], **kwargs):
"""
Initializes the FullgraphWrapper class with compilation and
execution-related configurations.

Args:
graph (fx.GraphModule): The graph represented in fx.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
sym_shape_indices (list[int]):
Indices of symbolic shape.

Keyword Args:
kwargs: Additional keyword arguments reserved for future
extensions or custom platforms.

"""
raise NotImplementedError

def __call__(self, *args) -> Any:
"""
Executes the wrapped graph for given input args.

Args:
*args: Variable length input arguments to be passed into the
graph. The symbolic shape is expected to be in position
`sym_shape_indices[0]`.

Returns:
Any: Output of the executed wrapped graph.
"""
raise NotImplementedError
152 changes: 147 additions & 5 deletions vllm/compilation/cuda_piecewise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None

usage_type: Optional[str] = None


class CUDAPiecewiseBackend:

Expand Down Expand Up @@ -96,6 +98,7 @@
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
usage_type="piecewise(general)", # for logging only
)

def check_for_ending_compilation(self):
Expand Down Expand Up @@ -139,27 +142,32 @@
self.check_for_ending_compilation()

# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
# if we're supposed to treat the piecewise graphs as a whole,
# which implies forward_context.skip_attention_cuda_graphs is False.
# In the latter case, we rely on a wrapper class to capture
# the full cudagraph outside the fx graph.
skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
if not entry.use_cudagraph or not skip_attention_cuda_graphs:
return entry.runnable(*args)

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
"Warming up %s/%s of %s usage for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
entry.usage_type,
runtime_shape)
return entry.runnable(*args)

if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
logger.debug("Capturing a cudagraph of %s usage for shape %s",
entry.usage_type,
runtime_shape)

input_addresses = [
Expand Down Expand Up @@ -216,3 +224,137 @@

entry.cudagraph.replay()
return entry.output


class FullCudagraphWrapper:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, sym_shape_indices: list[int],
):
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.sym_shape_indices = sym_shape_indices

self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine

self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

self.first_run_finished = False

Check failure on line 243 in vllm/compilation/cuda_piecewise_backend.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/compilation/cuda_piecewise_backend.py:243:81: E501 Line too long (99 > 80)

self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()

self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}


for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=False,
use_cudagraph=True,
usage_type="general",
)
if self.separate_attention_routine:
self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=False,
use_cudagraph=True,
usage_type="decode",
)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
return self.graph(*args)
list_args = list(args)
runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
forward_context = get_forward_context()

if forward_context.skip_attention_cuda_graphs:
# turn back to piecewise cudagraphs backend, which is responsible
# for capturing and running the piecewise cudagraphs.
return self.graph(*args)

# if not skip, the fx graph and its sub-graphs will only be supposed to
# eagerly run the compiled graphs, which should be cudagraph capturable
# as a whole.

concrete_size_entries = self.concrete_size_entries # default as general usage
if self.separate_attention_routine and forward_context.is_pure_decoding:
concrete_size_entries = self.concrete_size_entries_decode

Check failure on line 288 in vllm/compilation/cuda_piecewise_backend.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/compilation/cuda_piecewise_backend.py:288:81: E501 Line too long (86 > 80)
if not runtime_shape in concrete_size_entries:
# we don't need to do anything for this shape.
return self.graph(*args)

entry = concrete_size_entries[runtime_shape]

if entry.runnable is None:
entry.runnable = self.graph

if not entry.use_cudagraph:
return entry.runnable(*args)

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
logger.debug(
"Warming up %s/%s of %s usage for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
entry.usage_type,
runtime_shape)
return entry.runnable(*args)


# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.

logger.debug("Capturing a cudagraph of %s usage for shape %s",
entry.usage_type,
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

Check failure on line 326 in vllm/compilation/cuda_piecewise_backend.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F841)

vllm/compilation/cuda_piecewise_backend.py:326:33: F841 Local variable `stack` is assigned to but never used
with ExitStack() as stack:
# mind-exploding: carefully manage the reference and memory.

Check failure on line 328 in vllm/compilation/cuda_piecewise_backend.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (SIM117)

vllm/compilation/cuda_piecewise_backend.py:326:13: SIM117 Use a single `with` statement with multiple contexts instead of nested `with` statements
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

compilation_counter.num_cudagraph_captured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)

entry.cudagraph.replay()
return entry.output
23 changes: 17 additions & 6 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3974,13 +3974,21 @@
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
separate_attention_routine: bool = False
"""
Enable a distinct attention calls routine under an attention backend for full
cuda graph capturing. This is because some attention backends like FlashMLA,
FlashInfer, FA2, etc. implement different branches for mix prefill-decode and
pure decode cases. This flag enables us to potentially capture the cudagraph
separately for each branch.
"""

pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""

max_capture_size: int = field(default=None, init=False) # type: ignore

Check failure on line 3989 in vllm/config.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/config.py:3989:81: E501 Line too long (81 > 80)
"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore

Check failure on line 3991 in vllm/config.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/config.py:3991:81: E501 Line too long (81 > 80)
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
Expand Down Expand Up @@ -4172,13 +4180,16 @@

def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")

# NOTE: When full_cuda_graph is True, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx graph,
# we keep the piecewise fx graph structure but capture the full
# cudagraph outside the fx graph. This reduces some cpu overhead when
# the runtime batch_size is not cudagraph captured. This is only
# supported for separate_attention_routine.
if self.separate_attention_routine:
assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
if not self.splitting_ops:
self.splitting_ops = [] if self.full_cuda_graph else [
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
Expand All @@ -4186,7 +4197,7 @@

@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:

Check failure on line 4200 in vllm/config.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/config.py:4200:81: E501 Line too long (105 > 80)
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
Expand Down
12 changes: 9 additions & 3 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
# determine whether to use a full cudagraph for attention or piecewise
# cudagraphs that skip the attention part. By default true, we use piecewise
# cudagraphs.
skip_attention_cuda_graphs: bool = True,
is_pure_decoding: bool = False


_forward_context: Optional[ForwardContext] = None
Expand All @@ -115,7 +119,8 @@ def set_forward_context(
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
skip_attention_cuda_graphs: bool = True,
is_pure_decoding: bool = False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Expand All @@ -140,7 +145,8 @@ def set_forward_context(
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
skip_attention_cuda_graphs=skip_attention_cuda_graphs,
is_pure_decoding=is_pure_decoding,
)

try:
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ def use_custom_allreduce(cls) -> bool:
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa

@classmethod
def get_fullgraph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper" # noqa

@classmethod
def stateless_init_device_torch_dist_pg(
Expand Down
8 changes: 8 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,14 @@ def get_piecewise_backend_cls(cls) -> str:
Get piecewise backend class for piecewise graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa

@classmethod
def get_fullgraph_wrapper_cls(cls) -> str:
"""
Get fullgraph wrapper class for fullgraph static graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper" # noqa


@classmethod
def stateless_init_device_torch_dist_pg(
Expand Down
Loading
Loading