10 changes: 0 additions & 10 deletions docs/design/cuda_graphs.md
@@ -218,16 +218,6 @@ outputs = model.generate(
)
```

### Migration from legacy flags

Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:

* `use_cudagraph=False` → `NONE`.
* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.

Because these flags are deprecated and will be removed in the next major or minor release (v0.11.0 or v1.0.0), we recommend using `cudagraph_mode` instead.
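
For illustration, a minimal sketch of the equivalent `cudagraph_mode` settings (the `vllm.config` import path is assumed here, mirroring its use in the test changes in this PR):

```python
# Rough before/after sketch of the flag migration; import path assumed.
from vllm.config import CompilationConfig, CUDAGraphMode

# Previously: use_cudagraph=False
cfg_none = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)

# Previously: use_cudagraph=True, full_cuda_graph=False
cfg_piecewise = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)

# Previously: full_cuda_graph=True (the graceful fallback policy still applies)
cfg_full = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL)
```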

### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)

Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
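
A minimal sketch of the configuration this corresponds to (no piecewise splitting, so passes see the whole graph, plus a full CUDA graph mode); the `pass_config` flag name below is an assumption for illustration, not a verified API:

```python
# Sketch only: splitting_ops=[] and CUDAGraphMode.FULL come from the text above;
# the enable_attn_fusion flag name is assumed, not verified against the API.
from vllm.config import CompilationConfig, CUDAGraphMode

compilation_config = CompilationConfig(
    splitting_ops=[],                    # no piecewise splitting: passes see the whole graph
    cudagraph_mode=CUDAGraphMode.FULL,   # or FULL_DECODE_ONLY, depending on backend support
    pass_config={"enable_attn_fusion": True},  # assumed flag for the attention fusion pass
)
```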
3 changes: 1 addition & 2 deletions tests/compile/piecewise/test_multiple_graphs.py
@@ -203,7 +203,6 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -281,7 +280,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=False,
cudagraph_mode=CUDAGraphMode.NONE,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)
1 change: 0 additions & 1 deletion tests/compile/piecewise/test_simple.py
@@ -62,7 +62,6 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
1 change: 0 additions & 1 deletion tests/compile/piecewise/test_toy_llama.py
@@ -449,7 +449,6 @@ def benchmark():
if piecewise:
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
75 changes: 42 additions & 33 deletions tests/compile/test_config.py
@@ -4,6 +4,7 @@
from contextlib import nullcontext

import pytest
from pydantic import ValidationError

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
@@ -23,14 +24,6 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_use_cudagraphs_dynamic():
vllm_config = VllmConfig()
# CUDA graphs are enabled by default in the V1 configuration, so the
# legacy use_cudagraph flag should report True.
assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)
@@ -65,7 +58,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

compilation_config = {
"use_cudagraph": False, # speed things up a bit
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
}
with (
compilation_counter.expect(
@@ -83,20 +76,24 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):

# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.parametrize(
"cudagraph_mode", [CUDAGraphMode.FULL_AND_PIECEWISE, CUDAGraphMode.NONE]
)
def test_use_cudagraphs(vllm_runner, monkeypatch, cudagraph_mode):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
"cudagraph_mode": cudagraph_mode,
}
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
num_gpu_runner_capture_triggers=1
if cudagraph_mode != CUDAGraphMode.NONE
else 0,
num_cudagraph_captured=13 if cudagraph_mode != CUDAGraphMode.NONE else 0,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
@@ -249,25 +246,36 @@ def test_resolve_operator_overload():
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, True, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when using cudagraphs
([], None, 1, False, 2048, True, RuntimeError),
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capture size should be >= 1 when using cudagraphs
(None, 0, 1, False, 2048, True, RuntimeError),
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
@@ -276,14 +284,13 @@ def test_cudagraph_sizes_post_init(
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, Exception):
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)

cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
@@ -298,11 +305,13 @@ def test_cudagraph_sizes_post_init(
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()

assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
3 changes: 0 additions & 3 deletions tests/compile/test_decorator.py
@@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
53 changes: 0 additions & 53 deletions vllm/config/compilation.py
@@ -150,7 +150,6 @@ class CompilationConfig:
- [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
Expand All @@ -160,7 +159,6 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- Inductor compilation:
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
@@ -328,18 +326,6 @@ class CompilationConfig:
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation:

- False: cudagraph inside compilation is not used.\n
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.

Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND
_PIECEWISE instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Expand All @@ -357,15 +343,6 @@ class CompilationConfig:
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: bool | None = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
cudagraph_specialize_lora: bool = True
"""Whether to create separate cuda graphs for cases with and without active
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
@@ -578,36 +555,6 @@ def __post_init__(self) -> None:
self.inductor_compile_config["combo_kernels"] = True
self.inductor_compile_config["benchmark_combo_kernel"] = True

# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning(
"use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
)
if (
self.cudagraph_mode is not None
and self.cudagraph_mode != CUDAGraphMode.NONE
):
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning(
"full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
)
if (
self.cudagraph_mode is not None
and not self.cudagraph_mode.has_full_cudagraphs()
):
raise ValueError(
"full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.FULL

if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
14 changes: 6 additions & 8 deletions vllm/config/vllm.py
@@ -536,14 +536,6 @@ def __post_init__(self):
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)

# final migrate the deprecated flags
self.compilation_config.use_cudagraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
self.compilation_config.full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)

if self.parallel_config.enable_dbo:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
@@ -720,6 +712,12 @@ def _set_cudagraph_sizes(self):
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = dedup_sizes
# filter out sizes larger than max_cudagraph_capture_size
cudagraph_capture_sizes = [
i
for i in cudagraph_capture_sizes
if i <= max_cudagraph_capture_size
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mamba1_attn.py
@@ -63,7 +63,7 @@ def build(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
state_indices_for_decode = state_indices_tensor[:num_decodes]
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mamba2_attn.py
@@ -330,7 +330,7 @@ def build(

elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/short_conv_attn.py
@@ -81,7 +81,7 @@ def build(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
Expand Down