diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e816a20fe064..2af0e46ea15f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -443,6 +443,7 @@ steps: - vllm/ - tests/compile commands: + - pytest -v -s compile/test_config.py - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py - pytest -v -s compile/test_fusion_attn.py diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index b56cf61e782c..177a581587d0 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -218,16 +218,6 @@ outputs = model.generate( ) ``` -### Migration from legacy flags - -Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`: - -* `use_cudagraph=False` → `NONE`. -* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`. -* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy. - -As they are deprecated and will be removed in the next major or minor release, i.e., v0.11.0 or v1.0.0, we recommend using cudagraph_mode instead. - ### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism) Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs. diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 700f57ffb068..64d626bae483 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, + cudagraph_mode=CUDAGraphMode.PIECEWISE, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], use_inductor_graph_partition=use_inductor_graph_partition, @@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=False, + cudagraph_mode=CUDAGraphMode.NONE, splitting_ops=["silly::attention"], use_inductor_graph_partition=use_inductor_graph_partition, ) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 228859532ef4..a48af8a8952a 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -62,7 +62,6 @@ def _run_simple_model( vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, use_inductor=use_inductor, splitting_ops=splitting_ops, use_inductor_graph_partition=use_inductor_graph_partition, diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 6887673eb6a5..92998ede1699 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -449,7 +449,6 @@ def benchmark(): if piecewise: compilation_config = CompilationConfig( 
mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=cudagraph_sizes, ) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 7455147f2b95..bb66ef5529b1 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -2,8 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy from contextlib import nullcontext +from unittest.mock import patch import pytest +from pydantic import ValidationError from vllm.compilation.counter import compilation_counter from vllm.compilation.fix_functionalization import FixFunctionalizationPass @@ -11,7 +13,7 @@ from vllm.config.compilation import CompilationMode from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer +from vllm.utils.torch_utils import _is_torch_equal_or_newer def test_version(): @@ -23,14 +25,6 @@ def test_version(): assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") -def test_use_cudagraphs_dynamic(): - vllm_config = VllmConfig() - # Default V1 configuration now starts without cudagraphs enabled; the - # engine decides when to capture based on runtime settings instead of a - # blanket default. - assert vllm_config.compilation_config.use_cudagraph - - def test_copy_pass(): vllm_config = VllmConfig() inductor_pass = FixFunctionalizationPass(vllm_config) @@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val) compilation_config = { - "use_cudagraph": False, # speed things up a bit + "cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit } with ( compilation_counter.expect( @@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @pytest.mark.forked -@pytest.mark.parametrize("enabled", [True, False]) -def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): +@pytest.mark.parametrize( + "cudagraph_mode,num_cudagraph_captured", + [ + (CUDAGraphMode.NONE, 0), + (CUDAGraphMode.FULL_DECODE_ONLY, 1), + (CUDAGraphMode.PIECEWISE, 13), + (CUDAGraphMode.FULL_AND_PIECEWISE, 14), + ], +) +def test_use_cudagraphs( + vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured +): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = { "cudagraph_capture_sizes": [100], - "use_cudagraph": enabled, + "cudagraph_mode": cudagraph_mode, } + num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0 with ( compilation_counter.expect( num_graphs_seen=1, - num_gpu_runner_capture_triggers=1 if enabled else 0, - num_cudagraph_captured=13 if enabled else 0, + num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers, + num_cudagraph_captured=num_cudagraph_captured, ), # loading the model causes compilation (if enabled) to happen vllm_runner( @@ -168,19 +173,18 @@ def test_splitting_ops_dynamic(): assert not config.compilation_config.splitting_ops_contain_attention() # When use_inductor_graph_partition=True - if is_torch_equal_or_newer("2.9.0.dev"): - config = VllmConfig( - compilation_config=CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - use_inductor_graph_partition=True, - splitting_ops=["vllm::unified_attention"], - ) + config = VllmConfig( + 
compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + use_inductor_graph_partition=True, + splitting_ops=["vllm::unified_attention"], ) - # with inductor partition we use splitting_ops directly for - # partition rules - assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] + ) + # with inductor partition we use splitting_ops directly for + # partition rules + assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] - # When attn_fusion pass enabled, splitting_ops now default to attention ops. + # When attn_fusion pass enabled. config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, @@ -189,29 +193,41 @@ def test_splitting_ops_dynamic(): cudagraph_mode=CUDAGraphMode.PIECEWISE, ) ) - # With the new simplified logic, attention fusion works with splitting_ops - assert config.compilation_config.splitting_ops_contain_attention() - # cudagraph mode remains PIECEWISE - assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + assert config.compilation_config.splitting_ops == [] + # cudagraph mode also fall back to FULL + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL - # When both use_inductor_graph_partition and attn_fusion pass enabled. - if is_torch_equal_or_newer("2.9.0.dev"): + # splitting_ops can not contain attention ops when attn_fusion + # pass enabled. + with pytest.raises(ValidationError): config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_inductor_graph_partition=True, pass_config={"enable_attn_fusion": True, "enable_noop": True}, custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, + # work around for accessing all attntion ops + splitting_ops=CompilationConfig()._attention_ops, ) ) - # With inductor graph partition, attn_fusion and splitting_ops - # work together. Default splitting_ops include attention ops. - assert config.compilation_config.splitting_ops_contain_attention() - # enable_attn_fusion is directly supported under - # use_inductor_graph_partition=True, and cudagraph_mode - # is unchanged. - assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + + # When both use_inductor_graph_partition and attn_fusion pass enabled. + config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + use_inductor_graph_partition=True, + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ) + ) + # With inductor graph partition, attn_fusion and splitting_ops + # work together. Default splitting_ops include attention ops. + assert config.compilation_config.splitting_ops_contain_attention() + # enable_attn_fusion is directly supported under + # use_inductor_graph_partition=True, and cudagraph_mode + # is unchanged. 
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE def test_should_split(): @@ -293,25 +309,36 @@ def attention( "tp_size", "enable_sequence_parallelism", "max_num_batched_tokens", - "use_cudagraph", + "cudagraph_mode", "expected_max_size", ), [ - (None, None, 1, False, 2048, True, 512), - ([1, 2, 4], 4, 1, False, 2048, True, 4), - ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError), - ([1, 256], None, 1, False, 2048, 256), - ([], None, 1, False, 2048, False, 0), - (None, 0, 1, False, 2048, False, 0), + (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256), + ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4), + ( + [1, 2, 4], + 8, + 1, + False, + 2048, + CUDAGraphMode.FULL_AND_PIECEWISE, + ValidationError, + ), + ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256), + ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0), + (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0), # truncated to nearest multiple of 8 or 16 - (None, 257, 1, False, 2048, True, 256), - ([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list - ([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP - ([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens + (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256), + # max from list + ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15), + # filtered out 15 due to SP + ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4), + # limited by the max_tokens + ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4), # the list should contain at least 1 element when use cudagraph - ([], None, 1, False, 2048, True, RuntimeError), + ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError), # the max capturing size should be >= 1 when use cudagraph - (None, 0, 1, False, 2048, True, RuntimeError), + (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError), ], ) def test_cudagraph_sizes_post_init( @@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init( tp_size, enable_sequence_parallelism, max_num_batched_tokens, - use_cudagraph, + cudagraph_mode, expected_max_size, ): ctx = nullcontext() - if isinstance(expected_max_size, Exception): + if expected_max_size == ValidationError: ctx = pytest.raises(expected_max_size) - cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE - with ctx: + with ( + ctx, + patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size), + ): compilation_config = CompilationConfig( cudagraph_capture_sizes=cudagraph_capture_sizes, max_cudagraph_capture_size=max_cudagraph_capture_size, @@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init( engine_args = EngineArgs( model="facebook/opt-125m", tensor_parallel_size=tp_size, + max_num_seqs=min(max_num_batched_tokens, 128), max_num_batched_tokens=max_num_batched_tokens, compilation_config=compilation_config, ) vllm_config = engine_args.create_engine_config() - assert ( - vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size - ) + assert ( + vllm_config.compilation_config.max_cudagraph_capture_size + == expected_max_size + ) diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index c9d01f2317d2..1850cc8f1479 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc 
vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], use_inductor_graph_partition=use_inductor_graph_partition, @@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch ), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], use_inductor_graph_partition=use_inductor_graph_partition, @@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch ), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], use_inductor_graph_partition=use_inductor_graph_partition, diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 1a7d854352ae..6b009075abfa 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality( model, runner="generate", max_model_len=4000, - max_num_seqs=1, dtype=dtype, limit_mm_per_prompt={"video": 1}, - tensor_parallel_size=1, video_pruning_rate=video_pruning_rate, ) as vllm_model: # Generate output - this should not crash diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 68eb9420e70d..b0d1bc2bab30 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -206,7 +206,6 @@ class CompilationConfig: - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder] - CudaGraph capture: - - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] @@ -216,7 +215,6 @@ class CompilationConfig: [vllm.config.CompilationConfig.cudagraph_num_of_warmups] - [`cudagraph_copy_inputs`] [vllm.config.CompilationConfig.cudagraph_copy_inputs] - - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] @@ -396,18 +394,6 @@ class CompilationConfig: Warning: This flag is new and subject to change in addition more modes may be added. """ - use_cudagraph: bool = True - """Whether to use cudagraph inside compilation: - - - False: cudagraph inside compilation is not used.\n - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - - Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND - _PIECEWISE instead. - """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. @@ -425,15 +411,6 @@ class CompilationConfig: internally managed buffer. Default is False. Note that this flag is only effective when cudagraph_mode is PIECEWISE. 
""" - full_cuda_graph: bool | None = False - """whether to use a full cuda graph for the entire forward pass rather than - splitting certain operations such as attention into subgraphs. Thus this - flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models. - Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode= - FULL_AND_PIECEWISE instead. - """ cudagraph_specialize_lora: bool = True """Whether to create separate cuda graphs for cases with and without active LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used @@ -603,13 +580,19 @@ def validate_mode_before(cls, value: Any) -> Any: @field_validator("cudagraph_mode", mode="before") @classmethod def validate_cudagraph_mode_before(cls, value: Any) -> Any: - """ - enable parse the `cudagraph_mode` enum type from string - """ + """Enable parsing of the `cudagraph_mode` enum type from string.""" if isinstance(value, str): return CUDAGraphMode[value.upper()] return value + @field_validator("pass_config", mode="before") + @classmethod + def validate_pass_config_before(cls, value: Any) -> Any: + """Enable parsing of the `pass_config` field from a dictionary.""" + if isinstance(value, dict): + return PassConfig(**value) + return value + @field_validator("compile_cache_save_format") @classmethod def validate_compile_cache_save_format(cls, value: str) -> str: @@ -666,9 +649,6 @@ def __post_init__(self) -> None: func if isinstance(func, InductorPass) else CallableInductorPass(func) ) - if isinstance(self.pass_config, dict): - self.pass_config = PassConfig(**self.pass_config) - if self.pass_config.enable_qk_norm_rope_fusion: # TODO(zhuhaoran): support rope native forward match and remove this. # Linked issue: https://github.com/vllm-project/vllm/issues/28042 @@ -684,36 +664,6 @@ def __post_init__(self) -> None: self.inductor_compile_config["combo_kernels"] = True self.inductor_compile_config["benchmark_combo_kernel"] = True - # migrate the deprecated flags - if not self.use_cudagraph: - logger.warning( - "use_cudagraph is deprecated, use cudagraph_mode=NONE instead." - ) - if ( - self.cudagraph_mode is not None - and self.cudagraph_mode != CUDAGraphMode.NONE - ): - raise ValueError( - "use_cudagraph and cudagraph_mode are mutually" - " exclusive, prefer cudagraph_mode since " - "use_cudagraph is deprecated." - ) - self.cudagraph_mode = CUDAGraphMode.NONE - if self.full_cuda_graph: - logger.warning( - "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead." - ) - if ( - self.cudagraph_mode is not None - and not self.cudagraph_mode.has_full_cudagraphs() - ): - raise ValueError( - "full_cuda_graph and cudagraph_mode are " - "mutually exclusive, prefer cudagraph_mode " - "since full_cuda_graph is deprecated." - ) - self.cudagraph_mode = CUDAGraphMode.FULL - if self.use_inductor_graph_partition and not is_torch_equal_or_newer( "2.9.0.dev" ): @@ -891,20 +841,19 @@ def set_splitting_ops_for_inductor_graph_partition(self): def set_splitting_ops_for_attn_fusion(self): assert self.pass_config.enable_attn_fusion - # For dynamo-partition (non-inductor) attention fusion, - # set splitting_ops to empty to avoid splitting at attention ops - self.splitting_ops = [] - if self.cudagraph_mode.has_piecewise_cudagraphs(): - logger.warning_once( - "enable_attn_fusion is incompatible with piecewise " - "cudagraph when use_inductor_graph_partition is off. 
" - "In this case, splitting_ops will be set to empty " - "list, and cudagraph_mode will be set to FULL. " - "Please ensure you are using attention backends that " - "support cudagraph or set cudagraph_mode to NONE " - "explicitly if encountering any problems." - ) - self.cudagraph_mode = CUDAGraphMode.FULL + if self.splitting_ops is None: + self.splitting_ops = [] + if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.warning_once( + "enable_attn_fusion is incompatible with piecewise " + "cudagraph when use_inductor_graph_partition is off. " + "In this case, splitting_ops will be set to empty " + "list, and cudagraph_mode will be set to FULL. " + "Please ensure you are using attention backends that " + "support cudagraph or set cudagraph_mode to NONE " + "explicitly if encountering any problems." + ) + self.cudagraph_mode = CUDAGraphMode.FULL assert not self.splitting_ops_contain_attention(), ( "attention ops should not be in splitting_ops " diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 60458b26944a..f581267f73f7 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -656,14 +656,6 @@ def __post_init__(self): f"cudagraph_mode={self.compilation_config.cudagraph_mode}" ) - # final migrate the deprecated flags - self.compilation_config.use_cudagraph = ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - ) - self.compilation_config.full_cuda_graph = ( - self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ) - if self.parallel_config.enable_dbo: a2a_backend = self.parallel_config.all2all_backend assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( @@ -853,7 +845,9 @@ def _set_cudagraph_sizes(self): ) # de-duplicate the sizes provided by the config dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes)) - cudagraph_capture_sizes = dedup_sizes + cudagraph_capture_sizes = [ + i for i in dedup_sizes if i <= max_num_tokens + ] # sort to make sure the sizes are in ascending order cudagraph_capture_sizes.sort() else: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 909af09be255..8e949e53330c 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -123,7 +123,7 @@ def build( elif ( num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 4bc1057333a5..888734e5d2b6 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -302,7 +302,7 @@ def build( elif ( num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 22ad1054b35e..de0cb73db091 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -81,7 +81,7 @@ def build( elif ( num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph + and 
self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_(
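
Note: the hunks above replace the deprecated `use_cudagraph` / `full_cuda_graph` flags with the single `cudagraph_mode` enum (`NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, `FULL_AND_PIECEWISE`). The sketch below is not part of this diff; it only mirrors how the updated tests construct a config with the new flag. Import paths follow the test files touched above and may differ slightly between vLLM versions, and the capture sizes are illustrative placeholders.

```python
# Minimal sketch (assumed imports, mirroring the updated tests in this diff):
# configure CUDA graph behavior via cudagraph_mode instead of the removed
# use_cudagraph / full_cuda_graph flags.
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CompilationMode, CUDAGraphMode

# Piecewise CUDA graphs, as in tests/compile/piecewise/test_multiple_graphs.py
# after this change (previously expressed as use_cudagraph=True).
vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_capture_sizes=[1, 2],  # illustrative sizes
    )
)

# Disabling CUDA graphs entirely (previously use_cudagraph=False).
no_graph_config = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    cudagraph_mode=CUDAGraphMode.NONE,
)
```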