From c80944cbcff6f3f048648cb13bf35b13e240f8d6 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 5 Aug 2025 14:54:49 -0700
Subject: [PATCH 1/5] Only weakref very last graph

Signed-off-by: Yong Hoon Shin
---
 .../compile/fullgraph/test_multiple_graphs.py |  8 ++--
 vllm/compilation/backends.py                  | 22 ++++++++++-
 vllm/compilation/decorators.py                | 39 ++++++++++++++++++-
 vllm/compilation/piecewise_backend.py         |  3 +-
 vllm/compilation/wrapper.py                   |  8 ++--
 vllm/config/compilation.py                    |  8 +++-
 6 files changed, 75 insertions(+), 13 deletions(-)

diff --git a/tests/compile/fullgraph/test_multiple_graphs.py b/tests/compile/fullgraph/test_multiple_graphs.py
index 6d3788af9de0..7d2b33a28114 100644
--- a/tests/compile/fullgraph/test_multiple_graphs.py
+++ b/tests/compile/fullgraph/test_multiple_graphs.py
@@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@support_torch_compile
+@support_torch_compile(no_weak_ref_output=True)
 class CompiledAttention(nn.Module):
     def __init__(
         self,
@@ -144,8 +144,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.hidden_states[:bsz].copy_(x)
         x = self.attn_one(self.hidden_states[:bsz])
         self.hidden_states[:bsz].copy_(x)
-        x = self.attn_two(self.hidden_states[:bsz])
-        return x
+        y = self.attn_two(self.hidden_states[:bsz])
+        # Use x in the final output to test that its value is not
+        # overwritten by the call to self.attn_two when using cudagraphs
+        return x + y
 
 
 @torch.inference_mode
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 1e66f21ff638..115e6bdbed4b 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -386,6 +386,7 @@ def __init__(
         module: torch.fx.GraphModule,
         compile_submod_names: list[str],
         vllm_config: VllmConfig,
+        no_weak_ref_output: bool,
         vllm_backend: "VllmBackend",
     ):
         super().__init__(module)
@@ -398,6 +399,7 @@ def __init__(
         self.vllm_backend = vllm_backend
         # When True, it annoyingly dumps the torch.fx.Graph on errors.
         self.extra_traceback = False
+        self.no_weak_ref_output = no_weak_ref_output
 
     def run(self, *args):
         fake_args = [
@@ -462,6 +464,15 @@ def call_module(
                     current_platform.get_static_graph_wrapper_cls()
                 )
 
+                # By default, convert the output of the last graph in a
+                # compilation unit to a weakref to save some memory. However,
+                # if more than one submodule is compiled inside the model, the
+                # last-graph output of a non-last submodule should not be
+                # weakref'd, as its memory may be overwritten by graph replays.
+                # In these cases, no_weak_ref_output can be set to True.
+                weak_ref_output = (piecewise_backend.is_last_graph
+                                   and not self.no_weak_ref_output)
+
                 # Always assign PIECEWISE runtime mode to the
                 # CUDAGraphWrapper for piecewise_backend, to distinguish
                 # it from the FULL cudagraph runtime mode, no matter it
@@ -473,7 +484,7 @@ def call_module(
                     cudagraph_options=CUDAGraphOptions(
                         debug_log_enable=piecewise_backend.is_first_graph,
                         gc_disable=not piecewise_backend.is_first_graph,
-                        weak_ref_output=piecewise_backend.is_last_graph,
+                        weak_ref_output=weak_ref_output,
                     ),
                 )
             else:
@@ -534,6 +545,7 @@ class VllmBackend:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        no_weak_ref_output: bool = False,
         prefix: str = "",
     ):
         # if the model is initialized with a non-empty prefix,
@@ -557,6 +569,8 @@ def __init__(
             self.compilation_config
         )
 
+        self.no_weak_ref_output = no_weak_ref_output
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
@@ -742,7 +756,11 @@ def __call__(
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(
-            self.split_gm, submod_names_to_compile, self.vllm_config, self
+            self.split_gm, 
+            submod_names_to_compile,
+            self.vllm_config, 
+            self.no_weak_ref_output,
+            self
         ).run(*example_inputs)
 
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 11a18c0e6bb7..a492db484e3b 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -72,6 +72,13 @@ def support_torch_compile(
 ) -> Callable[[_T], _T]: ...
 
 
+@overload
+def support_torch_compile(
+    *,
+    no_weak_ref_output: bool = False,
+) -> Callable[[_T], _T]: ...
+
+
 @overload
 def support_torch_compile(
     *,
@@ -104,6 +111,7 @@ def support_torch_compile(
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
     mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
+    no_weak_ref_output: bool = False,
 ) -> Callable[[_T], _T] | _T:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -161,6 +169,32 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
         dim to be decorated with `mark_unbacked`. This is useful if we would
         like to enforce that dynamo does not specialize on 0/1 values in the
         case of dummy input such as for vision model compilation
+
+    If `no_weak_ref_output` is set to `True`, the output of the last graph
+    of each compiled nn.Module will not be converted to a weakref.
+    This conversion saves memory but is only safe when the output of the last
+    graph is not used by any subsequent CUDA graphs in the model forward.
+    In full cudagraph mode, this is ignored as full cudagraphs are always
+    captured for the full model.
+
+    The flag defaults to `False`, because in most cases the entire model is
+    being compiled, so the assumption that there is no other cuda graph
+    after the last graph holds. However, in rare cases, multiple submodules
+    are compiled within a single model. In this case, only the output of
+    the last graph of the last submodule is safe to be converted to a
+    weakref. For example, if a model has 2 submodules mod_A and mod_B that
+    are piecewise compiled + graph captured separately, e.g.:
+
+    def forward(self, x):
+        a_out = self.mod_A(x)
+        b_out = self.mod_B(a_out)
+        return a_out + b_out
+
+    Then the output of mod_A should NOT be converted to a weakref,
+    because the call to mod_B may overwrite `a_out`.
+    This is because vLLM shares a global memory pool for all CUDA graphs,
+    causing PyTorch to re-use memory where possible. mod_A should specify
+    `@support_torch_compile(no_weak_ref_output=True)` to protect its output.
     """
 
     def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +233,7 @@ def cls_decorator_helper(cls: _T) -> _T:
                     f"Argument {k} not found in the forward method of {cls}"
                 )
         return _support_torch_compile(
-            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if, no_weak_ref_output
         )
 
     if cls is not None:
@@ -242,6 +276,7 @@ def _support_torch_compile(
     dynamic_arg_dims: dict[str, int | list[int]],
     mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
+    no_weak_ref_output: bool = False,
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -291,7 +326,7 @@ def __init__(
 
         compilation_counter.num_models_seen += 1
         self.compiled = False
-        TorchCompileWithNoGuardsWrapper.__init__(self)
+        TorchCompileWithNoGuardsWrapper.__init__(self, no_weak_ref_output)
 
     cls.__init__ = __init__
 
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index 2931580afbbb..e5c0ba823787 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -51,7 +51,8 @@ def __init__(
         self.vllm_backend = vllm_backend
 
         self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
+        self.is_last_graph = (
+            piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.is_full_graph = total_piecewise_compiles == 1
 
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 493e57f97f0f..9c05764057f5 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -85,16 +85,18 @@ class TorchCompileWithNoGuardsWrapper:
     since we drop all guards.
     """
 
-    def __init__(self):
+    def __init__(self, no_weak_ref_output: bool = False):
         self.compiled = False
-
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
         mode = vllm_config.compilation_config.mode
         if mode is None:
             raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
-        backend = vllm_config.compilation_config.init_backend(vllm_config)
+        backend = vllm_config.compilation_config.init_backend(
+            vllm_config,
+            no_weak_ref_output
+        )
 
         options = {}
         if isinstance(backend, str) and backend == "inductor":
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index ca01cb3fb55d..57ac2900b9df 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -710,7 +710,11 @@ def __post_init__(self) -> None:
         if self.backend == "":
             self.backend = current_platform.simple_compile_backend
 
-    def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
+    def init_backend(
+        self,
+        vllm_config: "VllmConfig",
+        no_weak_ref_output: bool = False,
+    ) -> str | Callable:
         """
         Initialize the backend for the compilation config from a vllm config.
         Arguments:
@@ -748,7 +752,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
 
         # TODO[@lucaskabela]: See if we can forward prefix
         # https://github.com/vllm-project/vllm/issues/27045
-        return VllmBackend(vllm_config)
+        return VllmBackend(vllm_config, no_weak_ref_output)
 
     def post_init_cudagraph_sizes(self) -> None:
         """To complete the initialization after cudagraph related

From 41b3740954469ad13d03922c2640c4a251281527 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 26 Aug 2025 12:35:50 -0700
Subject: [PATCH 2/5] Fix no_weak_ref_output not working with subclasses

Signed-off-by: Yong Hoon Shin
---
 tests/compile/test_decorator.py       | 84 +++++++++++++++++++++++++++
 vllm/compilation/backends.py          | 11 ++--
 vllm/compilation/counter.py           |  2 +
 vllm/compilation/cuda_graph.py        |  2 +
 vllm/compilation/decorators.py        | 12 +++-
 vllm/compilation/piecewise_backend.py |  3 +-
 6 files changed, 106 insertions(+), 8 deletions(-)

diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py
index 1850cc8f1479..c7081fa8008d 100644
--- a/tests/compile/test_decorator.py
+++ b/tests/compile/test_decorator.py
@@ -284,3 +284,87 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
         # num_cudagraph_sizes * num cudagraphable graphs to capture
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
+
+
+def test_no_weak_ref_output_decorator():
+    # piecewise
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_cudagraph=True,
+            splitting_ops=["silly.attention"],
+            cudagraph_capture_sizes=[1, 2],
+        )
+    )
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
+
+    @support_torch_compile(no_weak_ref_output=False)
+    class A(nn.Module):
+        def __init__(
+            self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
+        ) -> None:
+            super().__init__()
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            x = x + x
+            attn_output = torch.empty_like(x)
+            torch.ops.silly.attention(x, x, x, attn_output)
+            x = attn_output
+            x = x * 3
+            return x
+
+    @support_torch_compile(no_weak_ref_output=True)
+    @support_torch_compile(no_weak_ref_output=False)
+    class B(A): ...
+
+    # no_weak_ref_output defaults to False
+    @support_torch_compile()
+    class C(B): ...
+
+    with compilation_counter.expect(
+        num_weakref_output_graphs=1,
+        # Single compile target (mod A); its VllmBackend is initialized
+        # with no_weak_ref_output set to False
+    ) and set_current_vllm_config(vllm_config):
+        mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
+
+    # A has support_torch_compile
+    with compilation_counter.expect(
+        num_graphs_seen=1,
+        num_piecewise_graphs_seen=3,
+        num_piecewise_capturable_graphs_seen=2,
+        num_backend_compilations=2,
+        num_cudagraph_captured=4,
+        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
+
+    with compilation_counter.expect(
+        num_weakref_output_graphs=0,
+    ) and set_current_vllm_config(vllm_config):
+        mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
+
+    # B also has support_torch_compile
+    with compilation_counter.expect(
+        num_graphs_seen=1,
+        num_piecewise_graphs_seen=3,
+        num_piecewise_capturable_graphs_seen=2,
+        num_backend_compilations=2,
+        num_cudagraph_captured=4,
+    ):
+        run_model(vllm_config, mod_B, cudagraph_runtime_mode)
+
+    with compilation_counter.expect(
+        num_weakref_output_graphs=1,
+    ) and set_current_vllm_config(vllm_config):
+        mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
+
+    # C has support_torch_compile
+    with compilation_counter.expect(
+        num_graphs_seen=1,
+        num_piecewise_graphs_seen=3,
+        num_piecewise_capturable_graphs_seen=2,
+        num_backend_compilations=2,
+        num_cudagraph_captured=4,
+    ):
+        run_model(vllm_config, mod_C, cudagraph_runtime_mode)
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 115e6bdbed4b..70cd24fdeb4d 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -470,8 +470,9 @@ def call_module(
                 # last-graph output of a non-last submodule should not be
                 # weakref'd, as its memory may be overwritten by graph replays.
                 # In these cases, no_weak_ref_output can be set to True.
-                weak_ref_output = (piecewise_backend.is_last_graph
-                                   and not self.no_weak_ref_output)
+                weak_ref_output = (
+                    piecewise_backend.is_last_graph and not self.no_weak_ref_output
+                )
 
                 # Always assign PIECEWISE runtime mode to the
                 # CUDAGraphWrapper for piecewise_backend, to distinguish
@@ -756,11 +757,11 @@ def __call__(
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(
-            self.split_gm, 
+            self.split_gm,
             submod_names_to_compile,
-            self.vllm_config, 
+            self.vllm_config,
             self.no_weak_ref_output,
-            self
+            self,
         ).run(*example_inputs)
 
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py
index 20918099f169..801af176e113 100644
--- a/vllm/compilation/counter.py
+++ b/vllm/compilation/counter.py
@@ -29,6 +29,8 @@ class CompilationCounter:
     num_compiled_artifacts_saved: int = 0
     # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
     stock_torch_compile_count: int = 0
+    # Number of piecewise graphs that will return a weakref output
+    num_weakref_output_graphs: int = 0
 
     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index a2e0abfebc2c..1c230c3b941c 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -181,6 +181,8 @@ def __call__(self, *args, **kwargs):
                 # any other cuda graph.
                 output = weak_ref_tensors(output)
+                compilation_counter.num_weakref_output_graphs += 1
+
             # here we always use weak ref for the output
             # to save memory
             entry.output = weak_ref_tensors(output)
 
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index a492db484e3b..e6fb7fa97f9e 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -34,6 +34,7 @@
 logger = init_logger(__name__)
 
 IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
+LAST_PIECEWISE_GRAPH_WEAKREF_KEY = "_last_graph_weakref_vllm"
 
 _T = TypeVar("_T", bound=type[nn.Module])
 
@@ -233,7 +234,11 @@ def cls_decorator_helper(cls: _T) -> _T:
                     f"Argument {k} not found in the forward method of {cls}"
                 )
         return _support_torch_compile(
-            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if, no_weak_ref_output
+            cls,
+            inferred_dynamic_arg_dims,
+            mark_unbacked_dims,
+            enable_if,
+            no_weak_ref_output,
         )
 
     if cls is not None:
@@ -294,6 +299,9 @@ def _support_torch_compile(
 
     setattr(cls, IGNORE_COMPILE_KEY, False)
 
+    # Setting this as a cls attribute ensures a child class's value overrides its parent's
+    setattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, no_weak_ref_output)
+
     def __init__(
         self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
     ):
@@ -324,6 +332,8 @@ def __init__(
         if self.do_not_compile:
             return
 
+        no_weak_ref_output = getattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, False)
+
         compilation_counter.num_models_seen += 1
         self.compiled = False
         TorchCompileWithNoGuardsWrapper.__init__(self, no_weak_ref_output)
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index e5c0ba823787..2931580afbbb 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -51,8 +51,7 @@ def __init__(
         self.vllm_backend = vllm_backend
 
         self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = (
-            piecewise_compile_index == total_piecewise_compiles - 1)
+        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
 
         self.is_full_graph = total_piecewise_compiles == 1
 

From 02c0eae02986c07cdccca429f4e70c556860717c Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 11 Nov 2025 23:02:39 -0800
Subject: [PATCH 3/5] Make no_weak_ref_output=True work with use_inductor_graph_partition=True

Signed-off-by: Yong Hoon Shin
---
 vllm/compilation/decorators.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index e6fb7fa97f9e..5cf3145292af 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -332,7 +332,7 @@ def __init__(
         if self.do_not_compile:
             return
 
-        no_weak_ref_output = getattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, False)
+        self.no_weak_ref_output = getattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, False)
 
         compilation_counter.num_models_seen += 1
         self.compiled = False
@@ -523,7 +523,10 @@ def patched_inline_call(self_):
 
 
 @contextlib.contextmanager
-def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
+def maybe_use_cudagraph_partition_wrapper(
+    vllm_config: VllmConfig,
+    no_weak_ref_output: bool,
+):
     """
     Context manager to set/unset customized cudagraph partition wrappers.
@@ -553,6 +556,13 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
         def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
             partition_id = metadata.partition_index
             num_partitions = metadata.num_partitions
+
+            # If no_weak_ref_output was passed to the compile decorator,
+            # do not convert the last partition's output to a weakref
+            weak_ref_output = (
+                partition_id == num_partitions - 1 and not no_weak_ref_output
+            )
+
             return static_graph_wrapper_class(
                 runnable=f,
                 vllm_config=vllm_config,
@@ -560,7 +570,7 @@ def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
                 cudagraph_options=CUDAGraphOptions(
                     debug_log_enable=partition_id == 0,
                     gc_disable=partition_id != 0,
-                    weak_ref_output=partition_id == num_partitions - 1,
+                    weak_ref_output=weak_ref_output,
                 ),
             )
 

From c5da93eb94ee96e505e6ef4f6d9e3adf810aab37 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 11 Nov 2025 23:20:52 -0800
Subject: [PATCH 4/5] Fix test_no_weak_ref_output_decorator

Signed-off-by: Yong Hoon Shin
---
 tests/compile/test_decorator.py | 57 ++++++++++++++++++++++-----------
 vllm/compilation/decorators.py  |  4 ++-
 vllm/compilation/wrapper.py     |  3 +-
 3 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py
index c7081fa8008d..c771eeeb408b 100644
--- a/tests/compile/test_decorator.py
+++ b/tests/compile/test_decorator.py
@@ -286,18 +286,40 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
     run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
 
-def test_no_weak_ref_output_decorator():
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_no_weak_ref_output_decorator(use_inductor_graph_partition, monkeypatch):
+    # Disable the compile cache so that we can count the number of
+    # compilations appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     # piecewise
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
+    expected_num_graphs_seen = 1
+    expected_num_cudagraph_captured = (
+        4  # num_cudagraph_sizes * num cudagraphs to capture
+    )
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        expected_num_piecewise_graphs_seen = 3
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+
     @support_torch_compile(no_weak_ref_output=False)
     class A(nn.Module):
         def __init__(
@@ -330,12 +352,11 @@ class C(B): ...
 
     # A has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
@@ -346,11 +367,11 @@ class C(B): ...
 
     # B also has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_B, cudagraph_runtime_mode)
 
@@ -361,10 +382,10 @@ class C(B): ...
 
     # C has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_C, cudagraph_runtime_mode)
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 5cf3145292af..f5b3582815a0 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -495,7 +495,9 @@ def patched_inline_call(self_):
             InliningInstructionTranslator, "inline_call_", patched_inline_call
         ),
         torch._dynamo.config.patch(**dynamo_config_patches),
-        maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+        maybe_use_cudagraph_partition_wrapper(
+            self.vllm_config, self.no_weak_ref_output
+        ),
         _torch27_patch_tensor_subclasses(),
     ):
         if envs.VLLM_USE_AOT_COMPILE:
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 9c05764057f5..3815b0815fd9 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -94,8 +94,7 @@ def __init__(self, no_weak_ref_output: bool = False):
             raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
 
         backend = vllm_config.compilation_config.init_backend(
-            vllm_config,
-            no_weak_ref_output
+            vllm_config, no_weak_ref_output
         )
 
         options = {}

From 0250007d1a4385b7bdeb57bb16ddee852d43ff2a Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Wed, 19 Nov 2025 14:50:03 -0800
Subject: [PATCH 5/5] Fix lint

Signed-off-by: Yong Hoon Shin
---
 vllm/compilation/wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 3815b0815fd9..524ff7bdb2e0 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -124,7 +124,7 @@ def __init__(self, no_weak_ref_output: bool = False):
 
         if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
             torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
-        self._compiled_bytecode = None
+        self._compiled_bytecode: CodeType | None = None
 
     def aot_compile(self, *args, **kwargs):
         if not hasattr(self._compiled_callable, "aot_compile"):
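
Usage sketch (illustrative only, not part of the patches above): the snippet
below shows how the no_weak_ref_output flag introduced by this series would be
applied when two submodules of one model are piecewise-compiled separately,
mirroring the mod_A/mod_B example in the support_torch_compile docstring. The
module, layer, and class names here are hypothetical; only the decorator and
its flag come from the series.

    import torch
    from torch import nn

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import VllmConfig


    @support_torch_compile(no_weak_ref_output=True)
    class SubmoduleA(nn.Module):
        # Non-last compiled submodule: its output is read again after
        # SubmoduleB's cudagraph replays, so its last-graph output must not
        # be converted to a weakref.
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    @support_torch_compile  # no_weak_ref_output defaults to False
    class SubmoduleB(nn.Module):
        # Last compiled submodule: nothing replays after it, so its
        # last-graph output can safely be converted to a weakref.
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    class TwoSubmoduleModel(nn.Module):
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            self.mod_A = SubmoduleA(vllm_config=vllm_config, prefix=f"{prefix}.mod_A")
            self.mod_B = SubmoduleB(vllm_config=vllm_config, prefix=f"{prefix}.mod_B")

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            a_out = self.mod_A(x)
            b_out = self.mod_B(a_out)
            # a_out is consumed after mod_B's graph replays; because vLLM
            # shares one cudagraph memory pool, a weakref'd a_out could have
            # been overwritten by mod_B without no_weak_ref_output=True.
            return a_out + b_out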