 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 from contextlib import nullcontext
+from unittest.mock import patch
 
 import pytest
+from pydantic import ValidationError
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.config.compilation import CompilationMode
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
-from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
+from vllm.utils.torch_utils import _is_torch_equal_or_newer
 
 
 def test_version():
@@ -23,14 +25,6 @@ def test_version():
     assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
 
 
-def test_use_cudagraphs_dynamic():
-    vllm_config = VllmConfig()
-    # Default V1 configuration now starts without cudagraphs enabled; the
-    # engine decides when to capture based on runtime settings instead of a
-    # blanket default.
-    assert vllm_config.compilation_config.use_cudagraph
-
-
 def test_copy_pass():
     vllm_config = VllmConfig()
     inductor_pass = FixFunctionalizationPass(vllm_config)
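
The boolean `use_cudagraph` flag that the deleted test exercised is superseded throughout this diff by the `CUDAGraphMode` enum. A minimal sketch of the correspondence, based on the `cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE` shim removed in the last hunk below (the enum also adds modes with no boolean equivalent, such as `FULL_DECODE_ONLY`):

```python
from vllm.config import CompilationConfig, CUDAGraphMode

# old style: CompilationConfig(use_cudagraph=True) / (use_cudagraph=False)
for legacy_enabled in (True, False):
    mode = CUDAGraphMode.PIECEWISE if legacy_enabled else CUDAGraphMode.NONE
    config = CompilationConfig(cudagraph_mode=mode)
    assert config.cudagraph_mode == mode
```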
@@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
 
     compilation_config = {
-        "use_cudagraph": False,  # speed things up a bit
+        "cudagraph_mode": CUDAGraphMode.NONE,  # speed things up a bit
     }
     with (
         compilation_counter.expect(
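
The dict-style override above passes an enum member for `"cudagraph_mode"`; the assumption is that the runner validates such dicts into a `CompilationConfig`, where pydantic accepts the enum value directly. The explicit equivalent:

```python
from vllm.config import CompilationConfig, CUDAGraphMode

# same effect as {"cudagraph_mode": CUDAGraphMode.NONE} above
compilation_config = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE
```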
@@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
 
 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
-@pytest.mark.parametrize("enabled", [True, False])
-def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
+@pytest.mark.parametrize(
+    "cudagraph_mode,num_cudagraph_captured",
+    [
+        (CUDAGraphMode.NONE, 0),
+        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
+        (CUDAGraphMode.PIECEWISE, 13),
+        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
+    ],
+)
+def test_use_cudagraphs(
+    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
+):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
     compilation_config = {
         "cudagraph_capture_sizes": [100],
-        "use_cudagraph": enabled,
+        "cudagraph_mode": cudagraph_mode,
     }
+    num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
     with (
         compilation_counter.expect(
             num_graphs_seen=1,
-            num_gpu_runner_capture_triggers=1 if enabled else 0,
-            num_cudagraph_captured=13 if enabled else 0,
+            num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
+            num_cudagraph_captured=num_cudagraph_captured,
         ),
         # loading the model causes compilation (if enabled) to happen
         vllm_runner(
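
The expected counter values in the new parametrization follow a simple pattern: zero captures under `NONE`, one full graph under `FULL_DECODE_ONLY`, 13 piecewise graphs under `PIECEWISE` (the value previously hard-coded as `13 if enabled else 0`, specific to this model and capture-size configuration), and piecewise plus one full graph under `FULL_AND_PIECEWISE`. A sketch of that reading:

```python
from vllm.config import CUDAGraphMode

NUM_PIECEWISE_GRAPHS = 13  # observed for this test model and config

def expected_cudagraph_captures(mode: CUDAGraphMode) -> int:
    # mirrors the num_cudagraph_captured column in the parametrization
    return {
        CUDAGraphMode.NONE: 0,
        CUDAGraphMode.FULL_DECODE_ONLY: 1,  # one full decode graph
        CUDAGraphMode.PIECEWISE: NUM_PIECEWISE_GRAPHS,
        CUDAGraphMode.FULL_AND_PIECEWISE: NUM_PIECEWISE_GRAPHS + 1,
    }[mode]
```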
@@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
     assert not config.compilation_config.splitting_ops_contain_attention()
 
     # When use_inductor_graph_partition=True
-    if is_torch_equal_or_newer("2.9.0.dev"):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                mode=CompilationMode.VLLM_COMPILE,
-                use_inductor_graph_partition=True,
-                splitting_ops=["vllm::unified_attention"],
-            )
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            splitting_ops=["vllm::unified_attention"],
         )
-        # with inductor partition we use splitting_ops directly for
-        # partition rules
-        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
+    )
+    # with inductor partition we use splitting_ops directly for
+    # partition rules
+    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
 
-    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
+    # When the attn_fusion pass is enabled.
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
@@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    # With the new simplified logic, attention fusion works with splitting_ops
-    assert config.compilation_config.splitting_ops_contain_attention()
-    # cudagraph mode remains PIECEWISE
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
 
-    # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if is_torch_equal_or_newer("2.9.0.dev"):
+    # splitting_ops cannot contain attention ops when the attn_fusion
+    # pass is enabled.
+    with pytest.raises(ValidationError):
         config = VllmConfig(
             compilation_config=CompilationConfig(
                 mode=CompilationMode.VLLM_COMPILE,
-                use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
+                # workaround for accessing all attention ops
+                splitting_ops=CompilationConfig()._attention_ops,
             )
         )
-        # With inductor graph partition, attn_fusion and splitting_ops
-        # work together. Default splitting_ops include attention ops.
-        assert config.compilation_config.splitting_ops_contain_attention()
-        # enable_attn_fusion is directly supported under
-        # use_inductor_graph_partition=True, and cudagraph_mode
-        # is unchanged.
-        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+    # When both use_inductor_graph_partition and attn_fusion pass enabled.
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        )
+    )
+    # With inductor graph partition, attn_fusion and splitting_ops
+    # work together. Default splitting_ops include attention ops.
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # enable_attn_fusion is directly supported under
+    # use_inductor_graph_partition=True, and cudagraph_mode
+    # is unchanged.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
 def test_should_split():
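
The hunk above leans on `splitting_ops_contain_attention()` in its assertions. A hedged sketch of what that predicate presumably checks, inferred from the names used in this test rather than from vLLM's implementation:

```python
# Assumption: an op list "contains attention" when any entry is one of the
# known attention ops (e.g. "vllm::unified_attention" used above).
def splitting_ops_contain_attention(
    splitting_ops: list[str], attention_ops: list[str]
) -> bool:
    return any(op in attention_ops for op in splitting_ops)

assert splitting_ops_contain_attention(
    ["vllm::unified_attention"], ["vllm::unified_attention"]
)
assert not splitting_ops_contain_attention([], ["vllm::unified_attention"])
```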
@@ -293,25 +309,36 @@ def attention(
         "tp_size",
         "enable_sequence_parallelism",
         "max_num_batched_tokens",
-        "use_cudagraph",
+        "cudagraph_mode",
         "expected_max_size",
     ),
     [
-        (None, None, 1, False, 2048, True, 512),
-        ([1, 2, 4], 4, 1, False, 2048, True, 4),
-        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
-        ([1, 256], None, 1, False, 2048, 256),
-        ([], None, 1, False, 2048, False, 0),
-        (None, 0, 1, False, 2048, False, 0),
+        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
+        (
+            [1, 2, 4],
+            8,
+            1,
+            False,
+            2048,
+            CUDAGraphMode.FULL_AND_PIECEWISE,
+            ValidationError,
+        ),
+        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
+        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
         # truncated to nearest multiple of 8 or 16
-        (None, 257, 1, False, 2048, True, 256),
-        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
-        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # filtered out 15 due to SP
-        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by the max_tokens
+        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        # max from list
+        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
+        # filtered out 15 due to SP
+        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
+        # limited by max_num_batched_tokens
+        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
         # the list should contain at least 1 element when using cudagraphs
-        ([], None, 1, False, 2048, True, RuntimeError),
+        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
         # the max capture size should be >= 1 when using cudagraphs
-        (None, 0, 1, False, 2048, True, RuntimeError),
+        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
     ],
 )
 def test_cudagraph_sizes_post_init(
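
The `expected_max_size` column is consistent with a small resolution rule implied by the rows: an explicit `max_cudagraph_capture_size` wins (257 truncates down to 256), otherwise the largest surviving entry of `cudagraph_capture_sizes` after filtering, with 256 as the default cap when neither is given. The sketch below is inferred from the table, not taken from vLLM's code; the `ValidationError` rows are handled separately by config validation:

```python
def expected_max_capture_size(
    sizes, explicit_max, tp_size, sequence_parallel, max_tokens
):
    if explicit_max is not None:
        return min(explicit_max, 256)  # 257 -> 256 per the truncation comment
    if sizes is None:
        return min(max_tokens, 256)  # default cap, matching the first row
    if sequence_parallel:
        sizes = [s for s in sizes if s % tp_size == 0]  # assumed SP filter
    sizes = [s for s in sizes if s <= max_tokens]
    return max(sizes, default=0)

assert expected_max_capture_size([1, 2, 4, 15], None, 2, True, 2048) == 4
assert expected_max_capture_size([1, 2, 4, 15], None, 1, False, 8) == 4
assert expected_max_capture_size(None, 257, 1, False, 2048) == 256
```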
@@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
     tp_size,
     enable_sequence_parallelism,
     max_num_batched_tokens,
-    use_cudagraph,
+    cudagraph_mode,
     expected_max_size,
 ):
     ctx = nullcontext()
-    if isinstance(expected_max_size, Exception):
+    if expected_max_size == ValidationError:
         ctx = pytest.raises(expected_max_size)
 
-    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
-    with ctx:
+    with (
+        ctx,
+        patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
+    ):
         compilation_config = CompilationConfig(
             cudagraph_capture_sizes=cudagraph_capture_sizes,
             max_cudagraph_capture_size=max_cudagraph_capture_size,
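
The new `patch(...)` pins `cuda_device_count_stateless` to `tp_size`, so parallel-config validation sees exactly as many devices as the case requests regardless of the test machine's hardware. The same pattern in isolation:

```python
from unittest.mock import patch

# Code that consults the patched name inside vllm.config.parallel sees a
# fake device count while the context is active.
with patch("vllm.config.parallel.cuda_device_count_stateless", return_value=2):
    pass  # e.g. build a config with tensor_parallel_size=2 on a GPU-less box
```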
@@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
         engine_args = EngineArgs(
             model="facebook/opt-125m",
             tensor_parallel_size=tp_size,
+            max_num_seqs=min(max_num_batched_tokens, 128),
             max_num_batched_tokens=max_num_batched_tokens,
             compilation_config=compilation_config,
         )
         vllm_config = engine_args.create_engine_config()
 
-    assert (
-        vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
-    )
+        assert (
+            vllm_config.compilation_config.max_cudagraph_capture_size
+            == expected_max_size
+        )
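
The added `max_num_seqs=min(max_num_batched_tokens, 128)` keeps the scheduler config self-consistent for the row with `max_num_batched_tokens=8`, where the default `max_num_seqs` would otherwise exceed the token budget (assumption: the engine requires `max_num_batched_tokens >= max_num_seqs`):

```python
# hypothetical check mirroring the assumed scheduler constraint
for max_num_batched_tokens in (8, 2048):
    max_num_seqs = min(max_num_batched_tokens, 128)
    assert max_num_seqs <= max_num_batched_tokens
```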