Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1d3153b
drafting docs for cuda graph v1
fhl2000 Sep 6, 2025
6a830d1
fix typos and minor polish
fhl2000 Sep 7, 2025
88d7346
fix broken table
fhl2000 Sep 7, 2025
ccc44f6
address comments
fhl2000 Sep 8, 2025
c12b82d
minor
fhl2000 Sep 8, 2025
35a0c54
fix pre-commit
fhl2000 Sep 8, 2025
631a8da
fix pre-commit again
fhl2000 Sep 8, 2025
e75a642
replace two images
fhl2000 Sep 9, 2025
c3f8115
replace one image
fhl2000 Sep 9, 2025
614e126
Merge branch 'main' into cudagraph_mode_docs
fhl2000 Sep 10, 2025
a02eb23
small fixing of torch_compile.md
fhl2000 Sep 15, 2025
a582107
Move assets
hmellor Sep 15, 2025
f723640
Formatting and `CUDA Graphs` consistency
hmellor Sep 15, 2025
7ef8153
Comment link formatting
hmellor Sep 15, 2025
8752c24
Fix `pre-commit`
hmellor Sep 15, 2025
0dd161d
`pre-commit` again...
hmellor Sep 15, 2025
50a73cb
address comments
fhl2000 Sep 15, 2025
bba2ba8
fix pre-commit
fhl2000 Sep 16, 2025
06d56de
Merge branch 'main' into cudagraph_mode_docs
fhl2000 Sep 16, 2025
8c2b392
fix links
fhl2000 Sep 19, 2025
0db6e26
modify notes for attn_ops fusion
fhl2000 Sep 20, 2025
08292de
update aiter_fa cudagraph_support
fhl2000 Sep 20, 2025
7813f6d
add some recent updates
fhl2000 Sep 27, 2025
9e549c8
small fix
fhl2000 Sep 27, 2025
9a5adf9
small
fhl2000 Sep 27, 2025
d53da8c
small
fhl2000 Sep 28, 2025
6e7e01c
Merge branch 'main' into cudagraph_mode_docs
fhl2000 Oct 3, 2025
99b3eb6
Update docs/design/cuda_graphs.md
fhl2000 Oct 7, 2025
1d57323
Apply suggestions from code review
fhl2000 Oct 7, 2025
f8dc933
adapt from review suggestions
fhl2000 Oct 7, 2025
2f5586b
fix default
fhl2000 Oct 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
220 changes: 220 additions & 0 deletions docs/design/cuda_graphs_v1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@


Check failure on line 2 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Multiple consecutive blank lines [Expected: 1; Actual: 2]
This write-up introduces the new CUDA Graph modes in vLLM v1 beyond previous [torch.compile Intergration](torch_compile.md). To summarize, we (a.) added flexible `cudagraph_mode` configuration, (b.) made full CUDA Graphs support orthogonal to compilation,and (c.) introduced a cudagraph dispatcher as a central controller that picks the desired runtime mode and cudagraphs per batch automatically.

Check failure on line 3 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

First line in a file should be a top-level heading [Context: "This write-up introduces the n..."]

Throughtout the document, we will walk through the motivation, cudagraph modes, the detailed design, and the usage examples of the CUDA Graph modes.

!!! note
In this document, we refer to pure decode (`max_query_len=1`) or speculative decode (`max_query_len =1+num_spec_tokens`) as **uniform decode** batches, and the opposite would be **non-uniform** batches (i.e., prefill or mixed prefill-decode batches).

!!! note
The following contents are based on the last commit of <gh-pr:20059>.

---

## 1. Motivation

In the past [torch.compile integration](torch_compile.md), we achieved a balance between performance and attention operation compatibility using piecewise compilation (+piecewise cudagraph). However, when users flipped on full CUDA Graphs, which relys on no splitting compilation, the experience used to be all-or-nothing, tightly coupled to compilation and therefore loss the flexibility attention supports (i.e., cascade attention is imcompatible with cudagraph). Many attention backends also weren’t ready for unified "full" cuda graphs capture (e.g., only FlashAttention 3 supports it currently) or only support cuda graphs for pure decode batches (e.g., Flashinfer, FlashMLA and Mamba etc.). That may lead to confusing performance/compatibility tradeoffs, inconsistent cudagraph supports and increasingly complex code structures.

So we seek to a more fine-grained cudagraph solution that can:

* Explicitly aware if a batch is prefill/mixed batch or a uniform decode batch and capture/replay the cudagraphs accordingly, as an unified full cudagraph for different cases of the same batchsize is commonly unfeasible (e.g., for many attention backends).
* Capture "full" while maintaining the abilities of piecewise cudagraph. i.e, can dispatch cudagraph-incompatible routines (e.g., cascade attention or mixed prefill/decode batches for some attention backends) into piecewise cudagraphs.
* Achieve centralized control of the cudagraph behavious via a dispatcher, which makes cudagraph dispatching easier understanding and more extensible.
* Final and also minor, make cuda graph support to models that does not fit vllm's torch.compile integration system design in v1.

Apart from the above concerns, we also found that when a batch cannot hit a full cudagraph, the host-side eager execution of the flattened compiled fx graph(previous behavior) can be slower than the piecewise compiled fx graph in python (see [here](gh-pr:20059>)). So we are in favor of maintaining the piecewise compilation when enabling full cudagraphs to reduce host-side overhead. We can safely do this as full cudagraph and compilation are actually orthogonal to each other.

---

## 2. CudagraphModes

`CUDAGraphMode` (enum type) is the single knob you tune in `CompilationConfig.cudagraph_mode`:

* `NONE` — turn CG off. Good for debugging.
* `PIECEWISE` — default in v1. Most flexible: attention or other cudagraph-incompatible operations stay eager, everything else goes into CG.
* `FULL` — a single-mode strategy, which only captures full cudagraphs for non-uniform batches, then uniform-decode batches reuse cudagraph of non-uniform batch of same batch_size , since they are compatible; can be good for small models or workloads with small prompts.
* `FULL_DECODE_ONLY` — full cudagraph for uniform decode, eager run for prefill/mixed etc; can be good for decode instances in a P/D setup where prefill is not as important so we can save some memory.
* `FULL_AND_PIECEWISE` — full cudagraph for uniform decode, piecewise cudagraph for others; the general most performant setting for most models.

Defaults: If you’re on v1 with piecewise compilation, we default to `PIECEWISE` for safety reason (For mamba mixer models, it's `FULL_AND_PIECEWISE`). Otherwise we default to `NONE`.

!!! note
We also fuse the subset modes `NONE`, `PIECEWISE`, and `FULL` as the concrete runtime modes for cudagraph dispatching, so they are treated as one of the "decode_mode" or "mixed_mode" at runtime.

While `NONE` , `PIECEWISE`, and `FULL` are single mode configurations and simply equivalent to past implementation of eager execution, piecewise cudagraph, and full cudagraph respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly appended dual mode configurations, which require dispatching to dynamically switch between concrete runtime modes according to runtime batches.

!!! note
Not all the above modes are valid for every attention backends. We will discuss the compatibility later. But for users experience, we alias `FULL` mode to `FULL_AND_PIECEWISE` (-O 3) or `FULL_DECODE_ONLY` (-O 0) for attention backends that supports cudagraph for only pure decode or uniform decode.

Check failure on line 49 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Trailing spaces [Expected: 0 or 2; Actual: 4]
---

## 3. Detailed Design

### 3.1 Overview

The new CUDA Graph logic is built on top of piecewise compilation and supports dual cudagraph runtime mode switching. To make the system work, there are two core classes, i.e., [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper] and [CudagraphDispatcher][vllm.v1.cudagraph_dispatcher.CudagraphDispatcher] and an auxiliary component, i.e., [CUDAGraphMode][vllm.config.compilation.CUDAGraphMode] (introduced above) used for runtime mode, [BatchDescriptor][vllm.forward_context.BatchDescriptor] serving as the dispatch key.

See the following figures for a quick comparison between the previous and current design patterns of cudagraph with inductor compilation. We can see that previously the cudagraph logic and compilation logic is tightly coupled into the vllm `PiecewiseBackend`, and cudagraph is implicitly dispatched by `batch_size` idlely. Now the cudagraph logic is separated into the `CUDAGraphWrapper` class, responsible for both full and piecewise cudagraphs abilities, and dispatching is **explicitly** done via **runtime mode** plus the `BatchDescriptor` as the **dispatch key** via `CudagraphDispatcher`.

![previous_design](../assets/design/cuda_graphs_v1/previous_design.jpg)

![new_design](../assets/design/cuda_graphs_v1/current_design.jpg)

### 3.2 [BatchDescriptor][vllm.forward_context.BatchDescriptor]

Check failure on line 64 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Headings should be surrounded by blank lines [Expected: 1; Actual: 0; Below] [Context: "### 3.2 [BatchDescriptor][vllm.forward_context.BatchDescriptor]"]
`BatchDescriptor` is a component within `ForwardContext`, alongside the cudagraph runtime modes, serving as the core structure for dispatching keys at runtime. The prototype is:
```python

Check failure on line 66 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Fenced code blocks should be surrounded by blank lines [Context: "```python"]
class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
```

Check failure on line 70 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Fenced code blocks should be surrounded by blank lines [Context: "```"]
where num_tokens can be the padded token length, and uniform_decode is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_schedual_token is divisible by that desired `max_query_len`.

Check failure on line 71 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Trailing spaces [Expected: 0 or 2; Actual: 1]

The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a cudagraph item. We are safe to exclude item like `uniform_query_len` because the it is a constant at runtime for a certain setup currently, e.g., it should be either `1` for commonly pure decode or `1+num_spec_tokens` for a vaildation phase of speculative decode.

!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths setting (<gh-pr:23679>), or other modifications needed to support cudagraphs for models that input is not necessarily token length awared (for example, some multi-modal inputs).

### 3.3 [CudagraphDispatcher][vllm.v1.cudagraph_dispatcher.CudagraphDispatcher]

The dispatcher takes responsibility for maintaining two sets of valid dispatching keys, one set for `FULL` runtime mode and one set for `PIECEWISE` runtime mode, and dispatches the correct runtime mode and the dispatching keys before executing model's forwards. It will take in the initial key (an rough batch_descriptor for the padded input) and return the selected runtime mode and the final batch_descriptor, then tells the CudgraphWarpper instances that decision through forward contexts. We should notice that CUDAGraphDispatcher is the only source of truth for available cudagraph keys, and the CUDAGraphWrapper instances could have less logic and unquestioningly trust the forward context on what cudagraph to dispatch to.

The initialization of dispatching keys are through the `initialize_cudagraph_keys` method of the dispatcher, which is called by the gpu_model_runner after all possible attention backends are initialized. This is actually the place where we can get much fancier in the future and “prepare” all kinds of cudagraph combos. But for now, we just append availabel keys based on the valid combos of `decode_mode`/`mixed_mode` of cudagraph_mode and `cudagraph_capture_sizes` in the compilation config.

The dispatching code is like:
```python

Check failure on line 85 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Fenced code blocks should be surrounded by blank lines [Context: "```python"]
batch_descriptor=BatchDescriptor(num_tokens=num_input_tokens, uniformed_decode=...)
runtime_mode, batch_descriptor = cudagraphdispatcher.dispatch(batch_descriptor)
# execution
with set_forward_context(...,
cudagraph_runtime_mode=runtime_mode,
batch_descriptor=batch_descriptor):
output = self.model(...)
```

Inside the `dispatch()` method, the dispatcher will search the proper cudagraph runtime mode and existing dispatching keys for a return. We basically search the existing keys following the priority: `FULL`>`PIECEWISE`>`None`. If the dispatching key does not exist, default to return `NONE` mode for eager execution. The implementations can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/v1/cudagraph_dispatcher.py#L91).

Here is an simplified illustration of the workflow at runtime in the model executor:
![executor_runtime](../assets/design/cuda_graphs_v1/executor_runtime.jpg)

### 3.4 [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper]

A `CUDAGraphWrapper` instance wraps a runnable and simply mimics the runable with appended cudagraph abilities. Each wrapper instance is bound to a specific `runtime_mode`, which is restricted to `PIECEWISE` and `FULL` mode, and takes responsibility for capturing/replaying and passing through (directly calling) the runnable. At runtime, each wrapper would:

1. inspect the runtime_mode and batch_descriptor(dispatching key) from the global forward context.

Check failure on line 104 in docs/design/cuda_graphs_v1.md

View workflow job for this annotation

GitHub Actions / pre-commit

Trailing spaces [Expected: 0 or 2; Actual: 1]
2. If runtime_mode is `NONE` or runtime_mode does not match the mode of the wrapper, just call the runnable directly.
3. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,the wrapper will perform cudagraph capture(if key does not exist, create
a new entry and cache it) or replay (if key exists in the cache).

The above steps is based on the assumption that the cudagraph wrapper would directly trusts what’s in the forward context (controlled by the dispatcher) without any fallback behavior. See the implementation [here](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/cuda_graph.py#L106).

#### 3.4.1 Nested Wrapper design
The core mechanism of making a full cudagraph and piecewise cudagraph coexist and compatible is the nested cudagraph wrapper design, building on top of piecewise compilation with only a single piecewise fx graph. We wrap a FULL mode wrapper outside the entire model for the full cudagraph functionality; meanwhile, each piecewise backend is wrapped via a `PIECEWISE` mode wrapper inside the compilation.

The flow chart below should clearly describe how it works.
![wrapper_flow](../assets/design/cuda_graphs_v1/wrapper_flow.png)

Therefore, for a `FULL` runtime mode, it is safe to capture/replay a full cudagraph since the piecewise wrapper is not activated. The situation is similar for `PIECEWISE` mode, as there are no conflicts between the `FULL` mode wrapper and `PIECEWISE` mode wrappers. For `NONE` runtime mode, both `FULL` and `PIECEWISE` wrappers would not be activated, so an eager execution is passed.

### 3.5 Full cudagraph capturing & warm-up

The cudagraph capturing is happened on the first call runner's dummy_run with non-`NONE` runtime mode. And for full cudagraph capture (pass `FULL` runtime mode), the core idea of explicitly capturing different cases (i.e. prefill/mixed batch or uniform_decode batch ) is
to tell the underlying attention backend to launch the desired kernel routines (i.e., may launch different kernels or combos for different cases) via carefully crafting the attn_metadatas. To distinguish prefill/mixed batch or uniform_decode batch, the most important property is the `max_query_len` in attn_metadata (true for most attention backends). we set it to the desired uniform_query_len for uniform_decode otherwise we make it just the `num_tokens` for a non-uniform_decode batch.

The cudagraph wrapper no longer manages the warm-up logic. The warm-up process is now controlled directly by the GPU model runner, where the `NONE` runtime mode is assigned to play an eager execution for warm-up. When warming up for a full cudagraph, it is also important to pass `force_attention=True` to the `dummy_run` function to explicitly warm up the attention backends.

---

## 4. Cudagraph Compatibility of Attention Backends

To signal the cuda graph compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which is a enum type that tracks the capability of the attention backend to support cudagraph. the value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`.

```python
class AttentionCGSupport(enum.Enum):
""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""

ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches the only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
NEVER = 0
"""NO cudagraph support"""
```

If we have hybrid attention backends (e.g. in mamba mixer models), we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible cudagraph mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].

We have a collection of backends supporting full cudagraph. See the table below for a reference:
| Attention Backend | cudagraph_support | Comments |
|-------------------|---------|
| FlashAttention v2 | `UNIFORM_BATCH` | Actually `ALWAYS` but workaround to fallback to `FULL_AND_PIECEWISE` for preformance reason |
| FlashAttention v3 | `ALWAYS` | have unified routine for batch, so `FULL` mode is good |
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it have different kernels for prefill/mixed and pure decode batches |
| AITER FA | `ALWAYS`| |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| FlashMLA | `UNIFORM_BATCH` | |
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | Default is `FULL_AND_PIECEWISE` |

Unlisted backends are all declared as `NEVER`.

---

## 5. Usage guide

Now the CLI is directly using the uppercase string of cudagraph_mode for compilation_config: `--compilation-config '{"cudagraph_mode": "..."}'`, where `...` should be one of `NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, and `FULL_AND_PIECEWISE`. Note that all `PIECEWISE` related modes require piecewise compilation, and all `FULL` related modes need cudagraph support of attention backends. For example:
```bash
vllm serve --model meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'
```

### 5.1 Python examples

```python
import os
os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG")

import vllm
from vllm.config import CUDAGraphMode

compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"}
model = vllm.LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
dtype='auto',
compilation_config = compilation_config,
)
sampling_params = vllm.SamplingParams(
temperature=0, # greedy decoding
max_tokens=1024,
)
outputs = model.generate(
"My name is john and",
sampling_params=sampling_params,
)
```

### 5.2 Migration from legacy flags

Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:

- `use_cudagraph=False``NONE`.
- `use_cudagraph=True` and `full_cuda_graph=False``PIECEWISE`.
- `full_cuda_graph=True` → directly set `FULL` and account for the gracefully fallback policy.

As they are deprecated and will be removed in the next major or minor release, i.e. v0.11.0 or v1.0.0. we recommend to use cudagraph_mode instead.

### 5.3 NOTE for attention ops fusion:
Currently, the default behavior of cudagraph_mode != `NONE` would always keep the attention ops in the splitting_ops to get piecewise fx graph, which means attention ops fusion is not compatible with piecewise cudagraph. In case one needs attention ops fusion, one can just manually passing `splitting_ops=[]` to compilation_config to retain the flattened fx graph, and using cudagraph_mode = "FULL" or "FULL_DECODE_ONLY" (should just avoid the PIECEWISE in mode even though we are using -O3). Currently, this RFC <gh-issue:23261> is tracking the progress of making attention ops fusion compatible with piecewise cudagraph to allow `FULL_AND_PIECEWISE` mode.

## 6. About the Performance
See the following link for examples:

comment1: <https://github.com/vllm-project/vllm/pull/20059#issuecomment-3160858458>
comment2: <https://github.com/vllm-project/vllm/pull/20059#issuecomment-3188735226>
comment3: <https://github.com/vllm-project/vllm/pull/20059#issuecomment-3219888738>

3 changes: 1 addition & 2 deletions docs/design/torch_compile.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,5 @@

### Full Cudagraph capture

It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`.
It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. See [CUDA Graphs v1](cuda_graphs_v1.md) for more details.

Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.