From 84b93a5bc26b780384de1281e03e2ebeb2f58e70 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 17 Sep 2025 14:05:22 -0700 Subject: [PATCH] remove the bytecode hook and replace TorchCompileWrapperWithCustomDispatcher with TorchCompileGuardsStripWrapper Signed-off-by: Laith Sakka --- .../compile/piecewise/test_multiple_graphs.py | 11 +- tests/compile/piecewise/test_simple.py | 3 + tests/compile/piecewise/test_toy_llama.py | 9 +- tests/compile/test_wrapper.py | 155 +++++++++--- .../multimodal/generation/test_qwen2_5_vl.py | 10 + tests/v1/e2e/test_spec_decode.py | 8 + vllm/compilation/decorators.py | 234 +++++++++--------- vllm/compilation/wrapper.py | 212 ++++++++++------ vllm/envs.py | 6 + vllm/v1/worker/tpu_model_runner.py | 10 +- 10 files changed, 422 insertions(+), 236 deletions(-) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 64d626bae483..6d3788af9de0 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -22,6 +22,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 @@ -193,7 +195,14 @@ def run_model( @pytest.mark.parametrize("use_inductor_graph_partition", [False, True]) -def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +@create_new_process_for_each_test("spawn") +def test_multi_graph_piecewise_compile( + use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch +): + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a48af8a8952a..e258133ab50a 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -21,6 +21,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter @@ -124,6 +126,7 @@ def _run_simple_model( @pytest.mark.parametrize("use_inductor", [True, False]) @torch.inference_mode() +@create_new_process_for_each_test("spawn") def test_simple_piecewise_compile(use_inductor): _run_simple_model( splitting_ops=["silly::attention"], diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 92998ede1699..915fbc6ce7f3 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -29,6 +29,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils.torch_utils import is_torch_equal_or_newer +from ...utils import create_new_process_for_each_test + # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 @@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor: ("inductor", True), # Inductor, Inductor partition ], ) +@create_new_process_for_each_test("spawn") def test_toy_llama( backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path ): @@ -513,4 +516,8 @@ def benchmark(): if __name__ == "__main__": - benchmark() + # Protect against subprocess reimport when using spawn_new_process_for_each_test + import os + + if os.environ.get("RUNNING_IN_SUBPROCESS") != "1": + benchmark() diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index da0afd9eaa49..356cac7af258 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -2,59 +2,134 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import pytest import torch -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationMode +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper +from vllm.config import ( + CompilationConfig, + CompilationMode, + VllmConfig, + set_current_vllm_config, +) class MyMod(torch.nn.Module): def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): - if cache is not None: - return x + cache - return x * 2 + if x.size()[0] >= 4: + return x * 2 + else: + return x * 100 -class MyWrapper(TorchCompileWrapperWithCustomDispatcher): +class MyWrapper(TorchCompileWithNoGuardsWrapper): def __init__(self, model): self.model = model - compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__( - compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE - ) + super().__init__() - def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): + def forward(self, x: torch.Tensor): # type: ignore[override] # this is the function to be compiled - return self.model(x, cache) - - def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None): - # let torch.compile compile twice - if len(self.compiled_codes) == 2: - dispatch_id = 0 if cache is None else 1 - with self.dispatch_to_code(dispatch_id): - return self.forward(x, cache) - else: - return self.compiled_callable(x, cache) + return self.model(x) + +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch): + """Test basic functionality of TorchCompileWithNoGuardsWrapper.""" + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") -def test_torch_compile_wrapper(): - mod = MyMod() - wrappers = [] - for i in range(3): + # Create a proper vLLM config instead of mocking + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig() + vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE + vllm_config.compilation_config.backend = "inductor" + + # Test DYNAMO_TRACE_ONCE + with set_current_vllm_config(vllm_config): torch._dynamo.reset() + mod = MyMod() + wrapper = MyWrapper(mod) + + # First call should trigger compilation + x = torch.tensor([1, 2, 3, 4]) + torch._dynamo.mark_dynamic(x, 0) + + result1 = wrapper(x) + expected1 = torch.tensor([2, 4, 6, 8]) + assert torch.allclose(result1, expected1), ( + f"Expected {expected1}, got {result1}" + ) + + # Second call should use compiled code + x2 = torch.tensor([1, 2, 3]) + result2 = wrapper(x2) + expected2 = torch.tensor([2, 4, 6]) + assert torch.allclose(result2, expected2), ( + f"Expected {expected2}, got {result2}" + ) + + # without the wrapper result would be different. + result3 = mod(x2) + expected3 = torch.tensor([100, 200, 300]) + + assert torch.allclose(result3, expected3), ( + f"Expected {result3}, got {expected3}" + ) + + # with STOCK_TORCH_COMPILE we do not remove guards. + vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE + torch._dynamo.reset() + with set_current_vllm_config(vllm_config): + mod = MyMod() wrapper = MyWrapper(mod) - wrappers.append(wrapper) - x = torch.tensor([1]) - wrapper(x, None) # profile run, compile - # create a cache tensor - cache = torch.tensor([2]) - wrapper(x, cache) # warm up with cache, recompile - - # for new input, dispatch to the compiled code directly - new_x = torch.tensor([3]) - assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code - assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code - - for wrapper in wrappers: - # make sure they have independent compiled codes - assert len(wrapper.compiled_codes) == 2 + + # First call should trigger compilation + x = torch.tensor([1, 2, 3, 4]) + torch._dynamo.mark_dynamic(x, 0) + + result1 = wrapper(x) + expected1 = torch.tensor([2, 4, 6, 8]) + assert torch.allclose(result1, expected1), ( + f"Expected {expected1}, got {result1}" + ) + + # Second call should triger another compilation + x2 = torch.tensor([1, 2, 3]) + result2 = wrapper(x2) + expected2 = torch.tensor([100, 200, 300]) + assert torch.allclose(result2, expected2), ( + f"Expected {expected2}, got {result2}" + ) + + # NO_COMPILATION level not supported. + vllm_config.compilation_config.mode = None + torch._dynamo.reset() + with set_current_vllm_config(vllm_config): + torch._dynamo.reset() + mod = MyMod() + + try: + wrapper = MyWrapper(mod) + except Exception: + return + raise AssertionError("expected an exception to be raised") + + +if __name__ == "__main__": + # Run with both parameter values + + class MockMonkeypatch: + def setenv(self, name, value): + os.environ[name] = value + + mp = MockMonkeypatch() + + print("Testing with VLLM_USE_BYTECODE_HOOK=False") + test_torch_compile_wrapper(False, mp) + + print("Testing with VLLM_USE_BYTECODE_HOOK=True") + test_torch_compile_wrapper(True, mp) + + print("All tests passed!") diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 6b009075abfa..3ba665710af4 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -34,6 +34,7 @@ def qwen2_5_vl_chat_template(*query): @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) def test_qwen2_5_vl_evs_functionality( vllm_runner, video_assets, @@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality( num_frames: int, dtype: str, max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, ) -> None: """Test EVS (Efficient Video Sampling) functionality with different pruning rates. """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") # Sample frames from video assets sampled_vids = [ @@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality( @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) def test_qwen2_5_vl_evs_batched_videos( vllm_runner, video_assets, @@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos( num_frames: int, dtype: str, max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, ) -> None: """Test EVS functionality with batched videos. @@ -102,6 +110,8 @@ def test_qwen2_5_vl_evs_batched_videos( 2. Both pruning configurations work with multiple videos 3. The model doesn't crash when processing multiple videos simultaneously """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") # Sample frames from video assets sampled_vids = [ sample_frames_from_video(asset.np_ndarrays, num_frames) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4a6b84ae4817..b81ab5c65da9 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -75,6 +75,14 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" +@pytest.fixture(autouse=True) +def reset_torch_dynamo(): + """Reset torch dynamo cache before each test""" + yield + # Cleanup after test + torch._dynamo.reset() + + @pytest.mark.parametrize( "speculative_config", [ diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0946fa69171b..e325bca73abb 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -17,7 +17,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper from vllm.config import ( CompilationMode, VllmConfig, @@ -246,14 +246,14 @@ def _support_torch_compile( """ A decorator to add support for compiling the forward method of a class. """ - if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + if TorchCompileWithNoGuardsWrapper in cls.__bases__: # support decorating multiple times return cls # take care of method resolution order # make sure super().__init__ is called on the base class - # other than TorchCompileWrapperWithCustomDispatcher - cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,) + # other than TorchCompileWithNoGuardsWrapper + cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,) old_init = cls.__init__ @@ -290,12 +290,43 @@ def __init__( return compilation_counter.num_models_seen += 1 - TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_mode=vllm_config.compilation_config.mode - ) + self.compiled = False + TorchCompileWithNoGuardsWrapper.__init__(self) cls.__init__ = __init__ + def _mark_dynamic_inputs(mod, *args, **kwargs): + sig = inspect.signature(mod.__class__.forward) + bound_args = sig.bind(mod, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}." + ) + if mark_unbacked_dims: + for k, dims in mark_unbacked_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + torch._dynamo.decorators.mark_unbacked(arg, dims) + def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't @@ -303,6 +334,7 @@ def __call__(self, *args, **kwargs): if self.do_not_compile or torch.compiler.is_compiling(): return self.forward(*args, **kwargs) + # if aot_compiled_fn is set, just call it. if getattr(self, "aot_compiled_fn", None) is not None: return self.aot_compiled_fn(self, *args, **kwargs) @@ -362,120 +394,84 @@ def __call__(self, *args, **kwargs): ) return self.aot_compiled_fn(self, *args, **kwargs) + if self.compiled: + assert not envs.VLLM_USE_AOT_COMPILE + return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + + # This is the path for the first compilation. + # the first compilation needs to have dynamic shapes marked - if len(self.compiled_codes) < 1: - sig = inspect.signature(self.__class__.forward) - bound_args = sig.bind(self, *args, **kwargs) - bound_args.apply_defaults() - for k, dims in dynamic_arg_dims.items(): - arg = bound_args.arguments.get(k) - if arg is not None: - dims = [dims] if isinstance(dims, int) else dims - if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.mark_dynamic(arg, dims) - elif isinstance(arg, IntermediateTensors): - for tensor in arg.tensors.values(): - # In case dims is specified with negative indexing - dims = [ - tensor.ndim + dim if dim < 0 else dim for dim in dims - ] - torch._dynamo.mark_dynamic(tensor, dims) - else: - raise ValueError( - "Unsupported dynamic dimensions" - f" {dims} for argument {k} with type {type(arg)}." - ) - if mark_unbacked_dims: - for k, dims in mark_unbacked_dims.items(): - arg = bound_args.arguments.get(k) - if arg is not None: - dims = [dims] if isinstance(dims, int) else dims - if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.decorators.mark_unbacked(arg, dims) - # here, it is the starting point of the `torch.compile` process - start_monitoring_torch_compile(self.vllm_config) - logger.debug("Start compiling function %s", self.original_code_object) - - # if we don't use custom dispatcher, we can directly call the - # compiled function and let torch.compile handle the dispatching, - # with the overhead of guard evaluation and recompilation. - if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: - # it seems Dynamo reuse the compilation across instances, - # while we need to make sure the compiled code is not reused. - # we need to control all the compilation of the model. - torch._dynamo.eval_frame.remove_from_cache(self.original_code_object) - - # collect all relevant files traced by Dynamo, - # so that the compilation cache can trigger re-compilation - # properly when any of these files change. - - # 1. the file containing the top-level forward function - self.vllm_config.compilation_config.traced_files.add( - self.original_code_object.co_filename - ) + _mark_dynamic_inputs(self, *args, **kwargs) - # 2. every time Dynamo sees a function call, it will inline - # the function by calling InliningInstructionTranslator.inline_call_ - # we hijack this function to know all the functions called - # during Dynamo tracing, and their corresponding files - inline_call = InliningInstructionTranslator.inline_call_ - - def patched_inline_call(self_): - code = self_.f_code - self.vllm_config.compilation_config.traced_files.add(code.co_filename) - return inline_call(self_) - - # Disable the C++ compilation of symbolic shape guards. C++-fication - # of symbolic shape guards can improve guard overhead. But, since - # vllm skip guards anyways, setting this flag to False can improve - # compile time. - dynamo_config_patches = {} - try: - _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards - dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False - except AttributeError: - # Note: this config is not available in torch 2.6, we can skip - # if the config doesn't exist - logger.debug("enable_cpp_symbolic_shape_guards config not available") - - with ( - patch.object( - InliningInstructionTranslator, "inline_call_", patched_inline_call - ), - torch._dynamo.config.patch(**dynamo_config_patches), - maybe_use_cudagraph_partition_wrapper(self.vllm_config), - _torch27_patch_tensor_subclasses(), - ): - if envs.VLLM_USE_AOT_COMPILE: - self.aot_compiled_fn = self.aot_compile(*args, **kwargs) - output = self.aot_compiled_fn(self, *args, **kwargs) - assert aot_compilation_path is not None - assert cache_dir is not None - try: - os.makedirs(cache_dir, exist_ok=True) - self.aot_compiled_fn.save_compiled_function( - aot_compilation_path - ) - except Exception as e: - logger.warning( - "Cannot save aot compilation to path %s, error: %s", - aot_compilation_path, - str(e), - ) - else: - output = self.compiled_callable(*args, **kwargs) - return output - - # usually, capturing the model once is enough, and then we can - # dispatch to the compiled code directly, without going through - # the Dynamo guard mechanism. - with self.dispatch_to_code(0): - model_output = self.forward(*args, **kwargs) - return model_output + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + original_code_object = self.original_code_object() + logger.debug("Start compiling function %s", original_code_object) + + # we do not want tp delete the original code object entries since + # we depend on them now to look up cached compiled functions. + # torch._dynamo.eval_frame.remove_from_cache(original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + original_code_object.co_filename + ) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call_ + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call_ + + def patched_inline_call(self_): + code = self_.f_code + self.vllm_config.compilation_config.traced_files.add(code.co_filename) + return inline_call(self_) + + # Disable the C++ compilation of symbolic shape guards. C++-fication + # of symbolic shape guards can improve guard overhead. But, since + # vllm skip guards anyways, setting this flag to False can improve + # compile time. + dynamo_config_patches = {} + try: + _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards + dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False + except AttributeError: + # Note: this config is not available in torch 2.6, we can skip + # if the config doesn't exist + logger.debug("enable_cpp_symbolic_shape_guards config not available") + + with ( + patch.object( + InliningInstructionTranslator, "inline_call_", patched_inline_call + ), + torch._dynamo.config.patch(**dynamo_config_patches), + maybe_use_cudagraph_partition_wrapper(self.vllm_config), + _torch27_patch_tensor_subclasses(), + ): + if envs.VLLM_USE_AOT_COMPILE: + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + assert aot_compilation_path is not None + assert cache_dir is not None + try: + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + except Exception as e: + logger.warning( + "Cannot save aot compilation to path %s, error: %s", + aot_compilation_path, + str(e), + ) + else: + output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + + self.compiled = True + return output cls.__call__ = __call__ return cls diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4d26619bd128..493e57f97f0f 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,11 +4,11 @@ import os import sys from abc import abstractmethod -from collections.abc import Callable from contextlib import contextmanager from types import CodeType import torch +import torch._C._dynamo.guards import vllm.envs as envs from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config @@ -17,88 +17,153 @@ logger = init_logger(__name__) -class TorchCompileWrapperWithCustomDispatcher: +def _noop_add_global_state_guard(self, *args, **kwargs): + """No-op to skip the GLOBAL_STATE guard entirely""" + pass + + +def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs): + """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely""" + pass + + +@contextmanager +def _compilation_context(): + """Context manager for compilation settings and patches. + + This manager: + 1. Sets higher dynamo cache limits for compilation. (Needed for + qwen2_5_vl see test_qwen2_5_vl_evs_functionality). + Generally a recompilation can happen whenever we use a new + backend instance in torch.compile. + 2. Patches out add_global_state_guard to skip GLOBAL_STATE guards + 3. Patches out add_torch_function_mode_stack_guard to skip + TORCH_FUNCTION_MODE_STACK guards. + 4. Restores everything when compilation completes """ - A wrapper class for torch.compile, with a custom dispatch logic. - Subclasses should: - 1. Implement the forward method - 2. Implement the dispatch logic in the __call__ method - It can use `self.compiled_codes` to access the compiled bytecode, - and `with self.dispatch_to_code(index):` to dispatch to - the compiled code. - 3. Implement the `__init__` method to determine how to call - `torch.compile` over the forward method. + # Save original values + original_global_state_guard = ( + torch._C._dynamo.guards.GuardManager.add_global_state_guard + ) + original_torch_function_mode_stack_guard = ( + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard + ) + original_cache_size = torch._dynamo.config.cache_size_limit + original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit + + try: + # Set higher cache limits for compilation + torch._dynamo.config.cache_size_limit = 2048 + torch._dynamo.config.accumulated_cache_size_limit = 8192 + + # Patch guard manager + torch._C._dynamo.guards.GuardManager.add_global_state_guard = ( + _noop_add_global_state_guard + ) + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = ( + _noop_add_torch_function_mode_stack_guard + ) + yield + finally: + # Restore original values + torch._C._dynamo.guards.GuardManager.add_global_state_guard = ( + original_global_state_guard + ) + torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = ( + original_torch_function_mode_stack_guard + ) + torch._dynamo.config.cache_size_limit = original_cache_size + torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache + + +class TorchCompileWithNoGuardsWrapper: """ + A wrapper class for torch.compile, it ensures that all guards are dropped + when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE. + When guards are dropped, the first time __call__ is invoked, a single + compilation is triggered. Dynamo should never be traced again after that + since we drop all guards. + """ + + def __init__(self): + self.compiled = False - def __init__( - self, - compiled_callable: Callable | None = None, - compilation_mode: CompilationMode = CompilationMode.NONE, - ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config - if compiled_callable is None: - # default compilation settings - # compiling the forward method - - backend = vllm_config.compilation_config.init_backend(vllm_config) - options = None - if isinstance(backend, str) and backend == "inductor": - options = ( - get_current_vllm_config().compilation_config.inductor_compile_config - ) - if envs.VLLM_USE_AOT_COMPILE: - options = options or {} - # This effectively drop all the guards. - # We need this because bytecode hook is not used any more to - # drop guards in the AOT compile mode. - options["guard_filter_fn"] = lambda guards: [False for _ in guards] - if hasattr(torch._dynamo.config, "enable_aot_compile"): - torch._dynamo.config.enable_aot_compile = True - else: - msg = "torch._dynamo.config.enable_aot_compile is not " - msg += "available. AOT compile is disabled and please " - msg += "upgrade PyTorch version to use AOT compile." - logger.warning(msg) - - compiled_callable = torch.compile( - self.forward, fullgraph=True, backend=backend, options=options - ) - - self.compiled_callable = compiled_callable - self.original_code_object = self.__class__.forward.__code__ - self.compiled_codes: list[CodeType] = [] - torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) - - # read the env var to determine whether to use the custom dispatcher - # subclasses can use this to switch between the custom dispatcher - # and the default Dynamo guard mechanism. - self.use_custom_dispatcher: bool = ( - compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE + mode = vllm_config.compilation_config.mode + if mode is None: + raise RuntimeError("Compilation mode cannot be NO_COMPILATION") + + backend = vllm_config.compilation_config.init_backend(vllm_config) + options = {} + + if isinstance(backend, str) and backend == "inductor": + options = vllm_config.compilation_config.inductor_compile_config + + if mode != CompilationMode.STOCK_TORCH_COMPILE: + # Drop all the guards. + options["guard_filter_fn"] = lambda x: [False for _ in x] + + if envs.VLLM_USE_AOT_COMPILE: + if hasattr(torch._dynamo.config, "enable_aot_compile"): + torch._dynamo.config.enable_aot_compile = True + else: + msg = "torch._dynamo.config.enable_aot_compile is not " + msg += "available. AOT compile is disabled and please " + msg += "upgrade PyTorch version to use AOT compile." + logger.warning(msg) + + self._compiled_callable = torch.compile( + self.forward, + fullgraph=True, + dynamic=False, + backend=backend, + options=options, ) + if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + self._compiled_bytecode = None + def aot_compile(self, *args, **kwargs): - if not hasattr(self.compiled_callable, "aot_compile"): + if not hasattr(self._compiled_callable, "aot_compile"): raise RuntimeError( "aot_compile is not supported by the current configuration. " + "Please make sure torch.compile is enabled with the latest " + f"version of PyTorch (current using torch: {torch.__version__})" ) - return self.compiled_callable.aot_compile((args, kwargs)) + return self._compiled_callable.aot_compile((args, kwargs)) def __call__(self, *args, **kwargs): - """Implement the dispatch logic here, beyond the torch.compile mode. - NOTE: this function can have additional arguments beyond the forward - method, for directly dispatching to the compiled code. - """ - return self.compiled_callable(*args, **kwargs) + if envs.VLLM_USE_BYTECODE_HOOK: + if ( + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + return self._compiled_callable(*args, **kwargs) + + if not self._compiled_bytecode: + # Make sure a compilation is triggered by clearing dynamo + # cache. + torch._dynamo.eval_frame.remove_from_cache(self.original_code_object()) + return self._compiled_callable(*args, **kwargs) + else: + with self._dispatch_to_compiled_code(): + return self.forward(*args, **kwargs) + else: + with _compilation_context(): + return self._compiled_callable(*args, **kwargs) @abstractmethod def forward(self, *args, **kwargs): ... + def original_code_object(self) -> CodeType: + """Return the original code object of the forward method.""" + return self.__class__.forward.__code__ + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" - if old_code is not self.original_code_object: + if old_code is not self.original_code_object(): return # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 frame = sys._getframe() @@ -114,7 +179,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): if frame.f_locals["self"] is not self: return - self.compiled_codes.append(new_code) + self._compiled_bytecode = new_code path = self.vllm_config.compile_debug_dump_path() if path: @@ -153,16 +218,21 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): raise RuntimeError(msg) @contextmanager - def dispatch_to_code(self, index: int): - """Context manager to dispatch to the compiled code. + def _dispatch_to_compiled_code(self): + # noqa: E501 + """ + Context manager to dispatch to internally compiled code for torch<2.8. Why does this work? Because Dynamo guarantees that the compiled bytecode has exactly the same arguments, cell variables, and free variables as the original code. Therefore we can directly switch the code object in the function and call it. - See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 - for more details. - """ - self.__class__.forward.__code__ = self.compiled_codes[index] - yield - self.__class__.forward.__code__ = self.original_code_object + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa: E501 line too long + original = self.original_code_object() + assert self._compiled_bytecode is not None + self.__class__.forward.__code__ = self._compiled_bytecode + try: + yield + finally: + self.__class__.forward.__code__ = original diff --git a/vllm/envs.py b/vllm/envs.py index 0530938c32f9..7987e5fb83fd 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -92,6 +92,7 @@ VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False VLLM_USE_AOT_COMPILE: bool = False + VLLM_USE_BYTECODE_HOOK: bool = False VLLM_FORCE_AOT_LOAD: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False @@ -556,6 +557,11 @@ def get_vllm_port() -> int | None: # compilation is done in warmup phase and the compilation will be # reused in subsequent calls. "VLLM_USE_AOT_COMPILE": use_aot_compile, + # Feature flag to enable/disable bytecode in + # TorchCompileWithNoGuardsWrapper. + "VLLM_USE_BYTECODE_HOOK": lambda: bool( + int(os.environ.get("VLLM_USE_BYTECODE_HOOK", "1")) + ), # Force vllm to always load AOT compiled models from disk. Failure # to load will result in a hard error when this is enabled. # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0f90578671db..01490e0dfac9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -21,7 +21,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper from vllm.config import ( ParallelConfig, VllmConfig, @@ -1895,12 +1895,14 @@ def reset_dynamo_cache(self): compiled_model = self.model.get_language_model().model else: compiled_model = self.model.model - if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): + if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper): logger.info("Clear dynamo cache and cached dynamo bytecode.") torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object + compiled_model.original_code_object() ) - compiled_model.compiled_codes.clear() + # Reset the wrapper to re-initialize. + compiled_model.compiled = False + TorchCompileWithNoGuardsWrapper.__init__(compiled_model) @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def select_hidden_states(self, hidden_states, indices_do_sample):