@@ -289,18 +289,40 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
289289 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
290290
291291
292- def test_no_weak_ref_output_decorator ():
292+ @pytest .mark .parametrize ("use_inductor_graph_partition" , [True , False ])
293+ def test_no_weak_ref_output_decorator (use_inductor_graph_partition , monkeypatch ):
294+ # Disable the compile cache so that the number of compilations can be
295+ # counted accurately.
296+ monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
297+
298+ if use_inductor_graph_partition and not is_torch_equal_or_newer ("2.9.0.dev" ):
299+ pytest .skip ("inductor graph partition is only available in PyTorch 2.9+" )
300+
293301 # piecewise
294302 vllm_config = VllmConfig (
295303 compilation_config = CompilationConfig (
296304 mode = CompilationMode .VLLM_COMPILE ,
297305 use_cudagraph = True ,
298- splitting_ops = ["silly. attention" ],
306+ splitting_ops = ["silly:: attention" ],
299307 cudagraph_capture_sizes = [1 , 2 ],
308+ use_inductor_graph_partition = use_inductor_graph_partition ,
300309 )
301310 )
302311 cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
303312
313+ expected_num_graphs_seen = 1
314+ expected_num_cudagraph_captured = (
315+ 4 # num_cudagraph_sizes * num cudagraphs to capture
316+ )
317+ if use_inductor_graph_partition :
318+ expected_num_piecewise_graphs_seen = 1
319+ expected_num_piecewise_capturable_graphs_seen = 1
320+ expected_num_backend_compilations = 1
321+ else :
322+ expected_num_piecewise_graphs_seen = 3
323+ expected_num_piecewise_capturable_graphs_seen = 2
324+ expected_num_backend_compilations = 2
325+
304326 @support_torch_compile (no_weak_ref_output = False )
305327 class A (nn .Module ):
306328 def __init__ (
@@ -333,12 +355,11 @@ class C(B): ...
333355
334356 # A has support_torch_compile
335357 with compilation_counter .expect (
336- num_graphs_seen = 1 ,
337- num_piecewise_graphs_seen = 3 ,
338- num_piecewise_capturable_graphs_seen = 2 ,
339- num_backend_compilations = 2 ,
340- num_cudagraph_captured = 4 ,
341- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
358+ num_graphs_seen = expected_num_graphs_seen ,
359+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
360+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
361+ num_backend_compilations = expected_num_backend_compilations ,
362+ num_cudagraph_captured = expected_num_cudagraph_captured ,
342363 ):
343364 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
344365
@@ -349,11 +370,11 @@ class C(B): ...
349370
350371 # B also has support_torch_compile
351372 with compilation_counter .expect (
352- num_graphs_seen = 1 ,
353- num_piecewise_graphs_seen = 3 ,
354- num_piecewise_capturable_graphs_seen = 2 ,
355- num_backend_compilations = 2 ,
356- num_cudagraph_captured = 4 ,
373+ num_graphs_seen = expected_num_graphs_seen ,
374+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
375+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
376+ num_backend_compilations = expected_num_backend_compilations ,
377+ num_cudagraph_captured = expected_num_cudagraph_captured ,
357378 ):
358379 run_model (vllm_config , mod_B , cudagraph_runtime_mode )
359380
@@ -364,10 +385,10 @@ class C(B): ...
364385
365386 # C has support_torch_compile
366387 with compilation_counter .expect (
367- num_graphs_seen = 1 ,
368- num_piecewise_graphs_seen = 3 ,
369- num_piecewise_capturable_graphs_seen = 2 ,
370- num_backend_compilations = 2 ,
371- num_cudagraph_captured = 4 ,
388+ num_graphs_seen = expected_num_graphs_seen ,
389+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
390+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
391+ num_backend_compilations = expected_num_backend_compilations ,
392+ num_cudagraph_captured = expected_num_cudagraph_captured ,
372393 ):
373394 run_model (vllm_config , mod_C , cudagraph_runtime_mode )
0 commit comments