@@ -286,18 +286,40 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
286286 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
287287
288288
289- def test_no_weak_ref_output_decorator ():
289+ @pytest .mark .parametrize ("use_inductor_graph_partition" , [True , False ])
290+ def test_no_weak_ref_output_decorator (use_inductor_graph_partition , monkeypatch ):
291+ # disable compile cache so that we can count the number of compilations
292+ # appropriately
293+ monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
294+
295+ if use_inductor_graph_partition and not is_torch_equal_or_newer ("2.9.0.dev" ):
296+ pytest .skip ("inductor graph partition is only available in PyTorch 2.9+" )
297+
290298 # piecewise
291299 vllm_config = VllmConfig (
292300 compilation_config = CompilationConfig (
293301 mode = CompilationMode .VLLM_COMPILE ,
294302 use_cudagraph = True ,
295- splitting_ops = ["silly. attention" ],
303+ splitting_ops = ["silly:: attention" ],
296304 cudagraph_capture_sizes = [1 , 2 ],
305+ use_inductor_graph_partition = use_inductor_graph_partition ,
297306 )
298307 )
299308 cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
300309
310+ expected_num_graphs_seen = 1
311+ expected_num_cudagraph_captured = (
312+ 4 # num_cudagraph_sizes * num cudagraphs to capture
313+ )
314+ if use_inductor_graph_partition :
315+ expected_num_piecewise_graphs_seen = 1
316+ expected_num_piecewise_capturable_graphs_seen = 1
317+ expected_num_backend_compilations = 1
318+ else :
319+ expected_num_piecewise_graphs_seen = 3
320+ expected_num_piecewise_capturable_graphs_seen = 2
321+ expected_num_backend_compilations = 2
322+
301323 @support_torch_compile (no_weak_ref_output = False )
302324 class A (nn .Module ):
303325 def __init__ (
@@ -330,12 +352,11 @@ class C(B): ...
330352
331353 # A has support_torch_compile
332354 with compilation_counter .expect (
333- num_graphs_seen = 1 ,
334- num_piecewise_graphs_seen = 3 ,
335- num_piecewise_capturable_graphs_seen = 2 ,
336- num_backend_compilations = 2 ,
337- num_cudagraph_captured = 4 ,
338- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
355+ num_graphs_seen = expected_num_graphs_seen ,
356+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
357+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
358+ num_backend_compilations = expected_num_backend_compilations ,
359+ num_cudagraph_captured = expected_num_cudagraph_captured ,
339360 ):
340361 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
341362
@@ -346,11 +367,11 @@ class C(B): ...
346367
347368 # B also has support_torch_compile
348369 with compilation_counter .expect (
349- num_graphs_seen = 1 ,
350- num_piecewise_graphs_seen = 3 ,
351- num_piecewise_capturable_graphs_seen = 2 ,
352- num_backend_compilations = 2 ,
353- num_cudagraph_captured = 4 ,
370+ num_graphs_seen = expected_num_graphs_seen ,
371+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
372+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
373+ num_backend_compilations = expected_num_backend_compilations ,
374+ num_cudagraph_captured = expected_num_cudagraph_captured ,
354375 ):
355376 run_model (vllm_config , mod_B , cudagraph_runtime_mode )
356377
@@ -361,10 +382,10 @@ class C(B): ...
361382
362383 # C has support_torch_compile
363384 with compilation_counter .expect (
364- num_graphs_seen = 1 ,
365- num_piecewise_graphs_seen = 3 ,
366- num_piecewise_capturable_graphs_seen = 2 ,
367- num_backend_compilations = 2 ,
368- num_cudagraph_captured = 4 ,
385+ num_graphs_seen = expected_num_graphs_seen ,
386+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
387+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
388+ num_backend_compilations = expected_num_backend_compilations ,
389+ num_cudagraph_captured = expected_num_cudagraph_captured ,
369390 ):
370391 run_model (vllm_config , mod_C , cudagraph_runtime_mode )
0 commit comments