8 changes: 5 additions & 3 deletions tests/compile/fullgraph/test_multiple_graphs.py
@@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@support_torch_compile
@support_torch_compile(no_weak_ref_output=True)
class CompiledAttention(nn.Module):
def __init__(
self,
@@ -144,8 +144,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.hidden_states[:bsz].copy_(x)
x = self.attn_one(self.hidden_states[:bsz])
self.hidden_states[:bsz].copy_(x)
x = self.attn_two(self.hidden_states[:bsz])
return x
y = self.attn_two(self.hidden_states[:bsz])
# Use x in the final output to test that its value is not
# overwritten by the call to self.attn_two when using cudagraphs
return x + y


@torch.inference_mode
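The test now returns `x + y` so that a stale `x` shows up in the output. As a rough CPU-side analogy of the hazard being tested (illustrative only, not vLLM code), a weak-ref output behaves like a view into shared memory that a later replay can overwrite:

import torch

# Illustrative analogy: the shared cudagraph memory pool modeled as a plain buffer.
pool = torch.zeros(4)
a_out = pool[:4]                    # "weak ref" output: a view that does not own the memory
pool.copy_(torch.full((4,), 2.0))   # a later graph replay writes into the same memory
print(a_out)                        # tensor([2., 2., 2., 2.]) -- the earlier output changed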
105 changes: 105 additions & 0 deletions tests/compile/test_decorator.py
@@ -284,3 +284,108 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
# num_cudagraph_sizes * num cudagraphable graphs to capture
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)


@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_no_weak_ref_output_decorator(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

# piecewise
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

expected_num_graphs_seen = 1
expected_num_cudagraph_captured = (
4 # num_cudagraph_sizes * num cudagraphs to capture
)
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
expected_num_piecewise_graphs_seen = 3
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2

@support_torch_compile(no_weak_ref_output=False)
class A(nn.Module):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
) -> None:
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x

@support_torch_compile(no_weak_ref_output=True)
@support_torch_compile(no_weak_ref_output=False)
class B(A): ...

# no_weak_ref_output defaults to False
@support_torch_compile()
class C(B): ...

with compilation_counter.expect(
num_weakref_output_graphs=1,
# Single compile target (mod A), one VllmBackend initialized
# no_weak_ref_output set to False
) and set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()

# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)

with compilation_counter.expect(
num_weakref_output_graphs=0,
) and set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()

# B also has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_B, cudagraph_runtime_mode)

with compilation_counter.expect(
num_weakref_output_graphs=1,
) and set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()

# C has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
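The expectations for A, B, and C above hinge on how the decorator records the flag: it is stored as a class attribute, so the outermost decorator application and the most derived class win. A toy sketch of just that mechanism (hypothetical names, not the vLLM implementation):

FLAG_KEY = "_last_graph_weakref_vllm"

def toy_support_torch_compile(no_weak_ref_output: bool = False):
    def decorator(cls):
        # Store the flag on the class itself, mirroring the real decorator's behavior.
        setattr(cls, FLAG_KEY, no_weak_ref_output)
        return cls
    return decorator

@toy_support_torch_compile(no_weak_ref_output=True)   # applied last, so its value sticks
@toy_support_torch_compile(no_weak_ref_output=False)
class ToyB: ...

@toy_support_torch_compile()                           # defaults back to False
class ToyC(ToyB): ...

assert getattr(ToyB, FLAG_KEY) is True    # stacked decorators: the outer application wins
assert getattr(ToyC, FLAG_KEY) is False   # the child's own attribute shadows the one from ToyB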
23 changes: 21 additions & 2 deletions vllm/compilation/backends.py
@@ -386,6 +386,7 @@ def __init__(
module: torch.fx.GraphModule,
compile_submod_names: list[str],
vllm_config: VllmConfig,
no_weak_ref_output: bool,
vllm_backend: "VllmBackend",
):
super().__init__(module)
@@ -398,6 +399,7 @@ def __init__(
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
self.no_weak_ref_output = no_weak_ref_output

def run(self, *args):
fake_args = [
@@ -462,6 +464,16 @@ def call_module(
current_platform.get_static_graph_wrapper_cls()
)

# By default, convert the output of the last graph in a compilation
# unit to a weakref to save some memory. However, if more than one
# submodule is compiled inside the model, the last-graph output of
# each non-final submodule must not be converted to a weakref, since
# its memory may be overwritten by subsequent graph replays. In those
# cases, no_weak_ref_output can be set to True.
weak_ref_output = (
piecewise_backend.is_last_graph and not self.no_weak_ref_output
)

# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
@@ -473,7 +485,7 @@
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
weak_ref_output=weak_ref_output,
),
)
else:
@@ -534,6 +546,7 @@ class VllmBackend:
def __init__(
self,
vllm_config: VllmConfig,
no_weak_ref_output: bool = False,
prefix: str = "",
):
# if the model is initialized with a non-empty prefix,
@@ -557,6 +570,8 @@ def __init__(
self.compilation_config
)

self.no_weak_ref_output = no_weak_ref_output

# `torch.compile` is JIT compiled, so we don't need to
# do anything here

@@ -742,7 +757,11 @@ def __call__(
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self
self.split_gm,
submod_names_to_compile,
self.vllm_config,
self.no_weak_ref_output,
self,
).run(*example_inputs)

graph_path = os.path.join(local_cache_dir, "computation_graph.py")
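The capture option added above reduces to one boolean: only the final piecewise graph may hand out a weak reference, and only if the owning module did not opt out. A small standalone restatement of that decision (helper name is hypothetical):

def resolve_weak_ref_output(is_last_graph: bool, no_weak_ref_output: bool) -> bool:
    # Mirrors the expression used in PiecewiseCompileInterpreter.call_module.
    return is_last_graph and not no_weak_ref_output

assert resolve_weak_ref_output(True, False) is True    # default: last graph returns a weakref
assert resolve_weak_ref_output(True, True) is False    # opted out: keep a strong reference
assert resolve_weak_ref_output(False, False) is False  # non-last graphs never return weakrefs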
2 changes: 2 additions & 0 deletions vllm/compilation/counter.py
@@ -29,6 +29,8 @@ class CompilationCounter:
num_compiled_artifacts_saved: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0
# Number of piecewise graphs that will return a weakref output
num_weakref_output_graphs: int = 0

def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
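The new counter is what the tests above assert on via `compilation_counter.expect(...)`. A minimal sketch of how such a delta check could work, assuming `expect` compares counter values before and after the `with` block (toy code, not the vLLM implementation):

import contextlib
import copy
from dataclasses import dataclass

@dataclass
class ToyCompilationCounter:
    num_weakref_output_graphs: int = 0

    def clone(self) -> "ToyCompilationCounter":
        return copy.deepcopy(self)

    @contextlib.contextmanager
    def expect(self, **deltas: int):
        # Snapshot the counters, run the block, then check that each named
        # field increased by exactly the expected amount.
        before = self.clone()
        yield
        for name, delta in deltas.items():
            got = getattr(self, name) - getattr(before, name)
            assert got == delta, f"{name}: expected +{delta}, got +{got}"

counter = ToyCompilationCounter()
with counter.expect(num_weakref_output_graphs=1):
    counter.num_weakref_output_graphs += 1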
2 changes: 2 additions & 0 deletions vllm/compilation/cuda_graph.py
@@ -181,6 +181,8 @@ def __call__(self, *args, **kwargs):
# any other cuda graph.
output = weak_ref_tensors(output)

compilation_counter.num_weakref_output_graphs += 1

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
67 changes: 62 additions & 5 deletions vllm/compilation/decorators.py
@@ -34,6 +34,7 @@
logger = init_logger(__name__)

IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
LAST_PIECEWISE_GRAPH_WEAKREF_KEY = "_last_graph_weakref_vllm"

_T = TypeVar("_T", bound=type[nn.Module])

@@ -72,6 +73,13 @@ def support_torch_compile(
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
*,
no_weak_ref_output: bool = False,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
*,
@@ -104,6 +112,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
no_weak_ref_output: bool = False,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -161,6 +170,32 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation

If `no_weak_ref_output` is set to `True`, the output of the last graph
of each compiled nn.Module will not be converted to a weakref.
The weakref conversion saves memory, but it is only safe when the output
of the last graph is not used by any subsequent CUDA graphs in the model
forward. In full cudagraph mode this flag is ignored, as a full cudagraph
is always captured for the entire model.

The conversion is enabled by default (`no_weak_ref_output=False`), because
in most cases the entire model is compiled as one unit, so the assumption
that no other cuda graph runs after the last graph holds. In rare cases,
however, multiple submodules are compiled within a single model, and only
the output of the last graph of the last submodule is safe to convert to
a weakref. For example, if a model has two submodules mod_A and mod_B that
are piecewise compiled + graph captured separately, e.g.:

def forward(self, x):
    a_out = self.mod_A(x)
    b_out = self.mod_B(a_out)
    return a_out + b_out

then the output of mod_A should NOT be converted to a weakref, because the
call to mod_B may overwrite `a_out`: vLLM shares a global memory pool
across all CUDA graphs, so PyTorch re-uses that memory where possible. To
keep its output from being overwritten, mod_A should be decorated with
`@support_torch_compile(no_weak_ref_output=True)`.
"""

def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +234,11 @@ def cls_decorator_helper(cls: _T) -> _T:
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
cls,
inferred_dynamic_arg_dims,
mark_unbacked_dims,
enable_if,
no_weak_ref_output,
)

if cls is not None:
@@ -242,6 +281,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
no_weak_ref_output: bool = False,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -259,6 +299,9 @@ def _support_torch_compile(

setattr(cls, IGNORE_COMPILE_KEY, False)

# Setting this as a class attribute ensures a decorated child class overrides its parent's value
setattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, no_weak_ref_output)

def __init__(
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
):
@@ -289,9 +332,11 @@ def __init__(
if self.do_not_compile:
return

self.no_weak_ref_output = getattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, False)

compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
TorchCompileWithNoGuardsWrapper.__init__(self, no_weak_ref_output)

cls.__init__ = __init__

@@ -450,7 +495,9 @@ def patched_inline_call(self_):
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
maybe_use_cudagraph_partition_wrapper(
self.vllm_config, self.no_weak_ref_output
),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:
@@ -478,7 +525,10 @@ def patched_inline_call(self_):


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
def maybe_use_cudagraph_partition_wrapper(
vllm_config: VllmConfig,
no_weak_ref_output: bool,
):
"""
Context manager to set/unset customized cudagraph partition wrappers.

@@ -508,14 +558,21 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions

# If no_weak_ref_output was passed to the compile decorator,
# do not convert the last partition's output to a weakref
weak_ref_output = (
partition_id == num_partitions - 1 and not no_weak_ref_output
)

return static_graph_wrapper_class(
runnable=f,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=partition_id == 0,
gc_disable=partition_id != 0,
weak_ref_output=partition_id == num_partitions - 1,
weak_ref_output=weak_ref_output,
),
)

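Putting the decorator change together, a hedged usage sketch mirroring the docstring's two-submodule example (class names are illustrative; import paths assumed from the files touched here):

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig

# The non-final submodule: its output is still needed after mod_b's graph replays,
# so it opts out of the weakref conversion.
@support_torch_compile(no_weak_ref_output=True)
class SubmoduleA(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# The final compiled submodule: the default weakref output remains safe here.
@support_torch_compile
class SubmoduleB(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

class Model(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        self.mod_a = SubmoduleA(vllm_config=vllm_config, prefix=f"{prefix}.mod_a")
        self.mod_b = SubmoduleB(vllm_config=vllm_config, prefix=f"{prefix}.mod_b")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a_out = self.mod_a(x)
        b_out = self.mod_b(a_out)
        return a_out + b_out  # a_out is read again after mod_b's cudagraph has replayed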