Skip to content

Commit c5da93e

Browse files
committed
Fix test_no_weak_ref_output_decorator
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 02c0eae commit c5da93e

File tree

3 files changed

+43
-21
lines changed

3 files changed

+43
-21
lines changed

tests/compile/test_decorator.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -286,18 +286,40 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
286286
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
287287

288288

289-
def test_no_weak_ref_output_decorator():
289+
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
290+
def test_no_weak_ref_output_decorator(use_inductor_graph_partition, monkeypatch):
291+
# disable compile cache so that we can count the number of compilations
292+
# appropriately
293+
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
294+
295+
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
296+
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
297+
290298
# piecewise
291299
vllm_config = VllmConfig(
292300
compilation_config=CompilationConfig(
293301
mode=CompilationMode.VLLM_COMPILE,
294302
use_cudagraph=True,
295-
splitting_ops=["silly.attention"],
303+
splitting_ops=["silly::attention"],
296304
cudagraph_capture_sizes=[1, 2],
305+
use_inductor_graph_partition=use_inductor_graph_partition,
297306
)
298307
)
299308
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
300309

310+
expected_num_graphs_seen = 1
311+
expected_num_cudagraph_captured = (
312+
4 # num_cudagraph_sizes * num cudagraphs to capture
313+
)
314+
if use_inductor_graph_partition:
315+
expected_num_piecewise_graphs_seen = 1
316+
expected_num_piecewise_capturable_graphs_seen = 1
317+
expected_num_backend_compilations = 1
318+
else:
319+
expected_num_piecewise_graphs_seen = 3
320+
expected_num_piecewise_capturable_graphs_seen = 2
321+
expected_num_backend_compilations = 2
322+
301323
@support_torch_compile(no_weak_ref_output=False)
302324
class A(nn.Module):
303325
def __init__(
@@ -330,12 +352,11 @@ class C(B): ...
330352

331353
# A has support_torch_compile
332354
with compilation_counter.expect(
333-
num_graphs_seen=1,
334-
num_piecewise_graphs_seen=3,
335-
num_piecewise_capturable_graphs_seen=2,
336-
num_backend_compilations=2,
337-
num_cudagraph_captured=4,
338-
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
355+
num_graphs_seen=expected_num_graphs_seen,
356+
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
357+
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
358+
num_backend_compilations=expected_num_backend_compilations,
359+
num_cudagraph_captured=expected_num_cudagraph_captured,
339360
):
340361
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
341362

@@ -346,11 +367,11 @@ class C(B): ...
346367

347368
# B also has support_torch_compile
348369
with compilation_counter.expect(
349-
num_graphs_seen=1,
350-
num_piecewise_graphs_seen=3,
351-
num_piecewise_capturable_graphs_seen=2,
352-
num_backend_compilations=2,
353-
num_cudagraph_captured=4,
370+
num_graphs_seen=expected_num_graphs_seen,
371+
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
372+
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
373+
num_backend_compilations=expected_num_backend_compilations,
374+
num_cudagraph_captured=expected_num_cudagraph_captured,
354375
):
355376
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
356377

@@ -361,10 +382,10 @@ class C(B): ...
361382

362383
# C has support_torch_compile
363384
with compilation_counter.expect(
364-
num_graphs_seen=1,
365-
num_piecewise_graphs_seen=3,
366-
num_piecewise_capturable_graphs_seen=2,
367-
num_backend_compilations=2,
368-
num_cudagraph_captured=4,
385+
num_graphs_seen=expected_num_graphs_seen,
386+
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
387+
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
388+
num_backend_compilations=expected_num_backend_compilations,
389+
num_cudagraph_captured=expected_num_cudagraph_captured,
369390
):
370391
run_model(vllm_config, mod_C, cudagraph_runtime_mode)

vllm/compilation/decorators.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,9 @@ def patched_inline_call(self_):
495495
InliningInstructionTranslator, "inline_call_", patched_inline_call
496496
),
497497
torch._dynamo.config.patch(**dynamo_config_patches),
498-
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
498+
maybe_use_cudagraph_partition_wrapper(
499+
self.vllm_config, self.no_weak_ref_output
500+
),
499501
_torch27_patch_tensor_subclasses(),
500502
):
501503
if envs.VLLM_USE_AOT_COMPILE:

vllm/compilation/wrapper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ def __init__(self, no_weak_ref_output: bool = False):
9494
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
9595

9696
backend = vllm_config.compilation_config.init_backend(
97-
vllm_config,
98-
no_weak_ref_output
97+
vllm_config, no_weak_ref_output
9998
)
10099
options = {}
101100

0 commit comments

Comments (0)