From b46663906a3f0ef96eed0beb69e0b4b80eb4b317 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 23 Sep 2025 02:58:00 +0000
Subject: [PATCH 1/3] Change the default CUDAGraphMode from PIECEWISE to
 FULL_AND_PIECEWISE

Signed-off-by: mgoin
---
 vllm/config/__init__.py    | 2 +-
 vllm/config/compilation.py | 7 +++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 92fc68f8927c..0e075ed365b0 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -510,7 +510,7 @@ def __post_init__(self):
             if envs.VLLM_USE_V1 and self.compilation_config.level \
                 == CompilationLevel.PIECEWISE:
                 self.compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.PIECEWISE
+                    CUDAGraphMode.FULL_AND_PIECEWISE
             else:
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 34fa7fcfe7e8..0441745e8b36 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -228,15 +228,14 @@ class CompilationConfig:
     The mode of the cudagraph:
 
     - NONE, no cudagraph capture.
-    - PIECEWISE. (v1 default)
+    - PIECEWISE.
     - FULL.
     - FULL_DECODE_ONLY.
-    - FULL_AND_PIECEWISE.
+    - FULL_AND_PIECEWISE. (v1 default)
 
     PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
     incompatible ops (i.e. some attention ops) outside the cudagraph
     for general flexibility.
-    This is the default mode.
 
     FULL mode: Capture full cudagraph for all batches. Can be good for small
     models or workloads with small prompts; not supported by many backends.
@@ -249,7 +248,7 @@ class CompilationConfig:
 
     FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
     piecewise cudagraph for prefill and mixed prefill-decode batches.
-    This is like the most performant mode for most models.
+    This is the most performant mode for most models and is the default.
 
     Currently, the cudagraph mode is only used for the v1 engine.
     Note that the cudagraph logic is generally orthogonal to the

From 91f371fa3ab48eca6674b2a38cdcff4f1b57f8df Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 23 Sep 2025 17:07:28 +0000
Subject: [PATCH 2/3] Reduce seq_lens to max_query_len and setup fallbacks

Signed-off-by: mgoin
---
 vllm/v1/worker/gpu_model_runner.py | 31 ++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 89b9a3c34f2a..834cb5163124 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2960,8 +2960,7 @@ def _dummy_run(
             # TODO(luka) better system for describing dummy batches
             seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
         else:
-            # Make sure max_model_len is used at the graph capture time.
-            seq_lens = self.max_model_len
+            seq_lens = max_query_len
         self.seq_lens.np[:num_reqs] = seq_lens
         self.seq_lens.np[num_reqs:] = 0
         self.seq_lens.copy_to_gpu()
@@ -3562,6 +3561,34 @@ def initialize_cudagraph_capture(self) -> None:
                 CUDAGraphMode.FULL_DECODE_ONLY
             logger.warning(msg)
 
+        # check that if we are doing decode full-cudagraphs it is supported
+        if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+                and min_cg_support == AttentionCGSupport.NEVER):
+            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
+                   f"with {min_cg_builder_name} backend (support: "
+                   f"{min_cg_support})")
+            if (self.compilation_config.level == CompilationLevel.PIECEWISE and
+                (self.compilation_config.splitting_ops_contain_attention()
+                 or self.compilation_config.use_inductor_graph_partition)):
+                msg += "; setting cudagraph_mode=PIECEWISE because "\
+                    "attention is compiled piecewise"
+                cudagraph_mode = self.compilation_config.cudagraph_mode = \
+                    CUDAGraphMode.PIECEWISE
+            else:
+                msg += "; setting cudagraph_mode=NONE because "\
+                    "attention is not compiled piecewise"
+                cudagraph_mode = self.compilation_config.cudagraph_mode = \
+                    CUDAGraphMode.NONE
+            logger.warning(msg)
+
+        # pooling model does not support full cudagraphs
+        if cudagraph_mode.has_full_cudagraphs() and self.is_pooling_model:
+            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
+                   "with pooling model; setting cudagraph_mode=PIECEWISE")
+            cudagraph_mode = self.compilation_config.cudagraph_mode = \
+                CUDAGraphMode.PIECEWISE
+            logger.warning(msg)
+
         # check that if we are doing spec-decode + decode full-cudagraphs it is
         # supported
         if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL

From 74be337c0e22ebd63a29480d490fad8f5d7d1822 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 23 Sep 2025 17:48:30 +0000
Subject: [PATCH 3/3] Move pooler special case to config

Signed-off-by: mgoin
---
 vllm/config/__init__.py            | 7 +++++++
 vllm/v1/worker/gpu_model_runner.py | 8 --------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 0e075ed365b0..6f79392094ef 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -509,8 +509,15 @@ def __post_init__(self):
         if self.compilation_config.cudagraph_mode is None:
             if envs.VLLM_USE_V1 and self.compilation_config.level \
                 == CompilationLevel.PIECEWISE:
+                # default to full and piecewise for most models
                 self.compilation_config.cudagraph_mode = \
                     CUDAGraphMode.FULL_AND_PIECEWISE
+
+                # pooling model does not support full cudagraphs
+                if self.model_config is not None and \
+                    self.model_config.pooler_config is not None:
+                    self.compilation_config.cudagraph_mode = \
+                        CUDAGraphMode.PIECEWISE
             else:
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 834cb5163124..7263a5610987 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3581,14 +3581,6 @@ def initialize_cudagraph_capture(self) -> None:
                 CUDAGraphMode.NONE
             logger.warning(msg)
 
-        # pooling model does not support full cudagraphs
-        if cudagraph_mode.has_full_cudagraphs() and self.is_pooling_model:
-            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
-                   "with pooling model; setting cudagraph_mode=PIECEWISE")
-            cudagraph_mode = self.compilation_config.cudagraph_mode = \
-                CUDAGraphMode.PIECEWISE
-            logger.warning(msg)
-
         # check that if we are doing spec-decode + decode full-cudagraphs it is
         # supported
         if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
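
Usage note (not part of the patches above): a minimal sketch of how a deployment
could pin the cudagraph mode explicitly rather than rely on the new
FULL_AND_PIECEWISE default. It assumes the offline LLM entrypoint still forwards
compilation_config to the engine, as in recent vLLM releases, and the model name
is only a placeholder.

# Hedged sketch: override the default cudagraph mode from user code.
from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    # Force piecewise-only capture instead of the FULL_AND_PIECEWISE default.
    compilation_config=CompilationConfig(
        cudagraph_mode=CUDAGraphMode.PIECEWISE),
)
outputs = llm.generate(["Hello, my name is"])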