Commit b30cc63

Fix test_no_weak_ref_output_decorator
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent a330e95

2 files changed (+39, -23 lines)

tests/compile/test_decorator.py

Lines changed: 39 additions & 18 deletions
@@ -289,18 +289,40 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
     run_model(vllm_config, mod_A, cudagraph_runtime_mode)


-def test_no_weak_ref_output_decorator():
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_no_weak_ref_output_decorator(use_inductor_graph_partition, monkeypatch):
+    # disable compile cache so that we can count the number of compilations
+    # appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     # piecewise
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

+    expected_num_graphs_seen = 1
+    expected_num_cudagraph_captured = (
+        4  # num_cudagraph_sizes * num cudagraphs to capture
+    )
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        expected_num_piecewise_graphs_seen = 3
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+
     @support_torch_compile(no_weak_ref_output=False)
     class A(nn.Module):
         def __init__(
@@ -333,12 +355,11 @@ class C(B): ...

     # A has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)

@@ -349,11 +370,11 @@ class C(B): ...

     # B also has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_B, cudagraph_runtime_mode)

@@ -364,10 +385,10 @@ class C(B): ...

     # C has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_C, cudagraph_runtime_mode)
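
The parametrized counters above encode a simple relationship; a minimal standalone sketch of that arithmetic follows (illustrative only, not part of the commit; the helper name and the explanation of the partition path are assumptions drawn from the diff's own comments):

def expected_counts(use_inductor_graph_partition: bool) -> dict:
    # Two cudagraph capture sizes ([1, 2]) times the graphs to capture gives 4,
    # per the "num_cudagraph_sizes * num cudagraphs to capture" comment.
    counts = {
        "num_graphs_seen": 1,
        "num_cudagraph_captured": 4,
    }
    if use_inductor_graph_partition:
        # Assumption: Inductor handles the partitioning internally, so vLLM's
        # splitting pass sees a single piecewise graph and compiles it once.
        counts.update(
            num_piecewise_graphs_seen=1,
            num_piecewise_capturable_graphs_seen=1,
            num_backend_compilations=1,
        )
    else:
        # Splitting on "silly::attention" yields 3 pieces, 2 of them capturable.
        counts.update(
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
        )
    return counts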

vllm/compilation/decorators.py

Lines changed: 0 additions & 5 deletions
@@ -286,11 +286,6 @@ def _support_torch_compile(
     """
     A decorator to add support for compiling the forward method of a class.
     """
-    setattr(cls, IGNORE_COMPILE_KEY, False)
-
-    # setting as attribute on cls ensures child class will override parent class
-    setattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, no_weak_ref_output)
-
     if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
         # support decorating multiple times
         return cls
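
The removed lines relied on a standard Python pattern described by their own comment: setting the flag as an attribute on cls (rather than on a shared base) lets a decorated child class shadow the value set by its decorated parent. A minimal standalone sketch of that pattern (hypothetical names, not vLLM's actual keys or decorator):

KEY = "_no_weak_ref_output"

def decorate(no_weak_ref_output: bool):
    def wrap(cls):
        # the attribute lives on each decorated class, so Child's value
        # shadows Parent's via normal attribute lookup
        setattr(cls, KEY, no_weak_ref_output)
        return cls
    return wrap

@decorate(no_weak_ref_output=False)
class Parent: ...

@decorate(no_weak_ref_output=True)
class Child(Parent): ...

assert getattr(Parent, KEY) is False
assert getattr(Child, KEY) is True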
