11 changes: 10 additions & 1 deletion tests/compile/piecewise/test_multiple_graphs.py
@@ -22,6 +22,8 @@
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

@@ -193,7 +195,14 @@ def run_model(


@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@create_new_process_for_each_test("spawn")
def test_multi_graph_piecewise_compile(
use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
):
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

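These tests toggle between the two wrapper implementations through the VLLM_USE_BYTECODE_HOOK environment variable, and each parametrization runs in a freshly spawned process so process-level state such as torch.compile caches cannot leak between settings. A minimal sketch of how such a flag is typically read on the consuming side (illustrative only; the helper name is hypothetical, not vLLM's actual layout):

import os


def use_bytecode_hook() -> bool:
    # Hypothetical helper: interpret VLLM_USE_BYTECODE_HOOK ("1"/"0") as a bool.
    # The tests above set the variable via monkeypatch.setenv before this point.
    return os.environ.get("VLLM_USE_BYTECODE_HOOK", "0") == "1"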
3 changes: 3 additions & 0 deletions tests/compile/piecewise/test_simple.py
@@ -21,6 +21,8 @@
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter

@@ -124,6 +126,7 @@ def _run_simple_model(

@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor):
_run_simple_model(
splitting_ops=["silly::attention"],
9 changes: 8 additions & 1 deletion tests/compile/piecewise/test_toy_llama.py
@@ -29,6 +29,8 @@
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ...utils import create_new_process_for_each_test

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

@@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
("inductor", True), # Inductor, Inductor partition
],
)
@create_new_process_for_each_test("spawn")
def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
@@ -513,4 +516,8 @@ def benchmark():


if __name__ == "__main__":
benchmark()
# Protect against subprocess reimport when using spawn_new_process_for_each_test
import os

if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
benchmark()
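For context on the new guard: with the "spawn" start method, a child process re-imports the __main__ module, so an unguarded top-level benchmark() call would run again in every spawned child; the RUNNING_IN_SUBPROCESS check adds a second layer on top of the standard guard for the spawn-based test helper. A self-contained illustration of the underlying re-import behavior (generic Python, independent of the helpers in this PR):

# The child re-imports this module under "spawn", so anything outside the
# __main__ guard would execute again in the child process.
import multiprocessing


def child_task(x: int) -> None:
    print(f"child got {x}")


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=child_task, args=(42,))
    p.start()
    p.join()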
155 changes: 115 additions & 40 deletions tests/compile/test_wrapper.py
@@ -2,59 +2,134 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import os

import pytest
import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
CompilationConfig,
CompilationMode,
VllmConfig,
set_current_vllm_config,
)


class MyMod(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
if cache is not None:
return x + cache
return x * 2
if x.size()[0] >= 4:
return x * 2
else:
return x * 100


class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
class MyWrapper(TorchCompileWithNoGuardsWrapper):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(
compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
)
super().__init__()

def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
def forward(self, x: torch.Tensor): # type: ignore[override]
# this is the function to be compiled
return self.model(x, cache)

def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
# let torch.compile compile twice
if len(self.compiled_codes) == 2:
dispatch_id = 0 if cache is None else 1
with self.dispatch_to_code(dispatch_id):
return self.forward(x, cache)
else:
return self.compiled_callable(x, cache)
return self.model(x)


@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
"""Test basic functionality of TorchCompileWithNoGuardsWrapper."""
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

def test_torch_compile_wrapper():
mod = MyMod()
wrappers = []
for i in range(3):
# Create a proper vLLM config instead of mocking
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig()
vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
vllm_config.compilation_config.backend = "inductor"

# Test DYNAMO_TRACE_ONCE
with set_current_vllm_config(vllm_config):
torch._dynamo.reset()
mod = MyMod()
wrapper = MyWrapper(mod)

# First call should trigger compilation
x = torch.tensor([1, 2, 3, 4])
torch._dynamo.mark_dynamic(x, 0)

result1 = wrapper(x)
expected1 = torch.tensor([2, 4, 6, 8])
assert torch.allclose(result1, expected1), (
f"Expected {expected1}, got {result1}"
)

# Second call should use compiled code
x2 = torch.tensor([1, 2, 3])
result2 = wrapper(x2)
expected2 = torch.tensor([2, 4, 6])
assert torch.allclose(result2, expected2), (
f"Expected {expected2}, got {result2}"
)

# Without the wrapper, the result would be different.
result3 = mod(x2)
expected3 = torch.tensor([100, 200, 300])

assert torch.allclose(result3, expected3), (
f"Expected {result3}, got {expected3}"
)

# with STOCK_TORCH_COMPILE we do not remove guards.
vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
torch._dynamo.reset()
with set_current_vllm_config(vllm_config):
mod = MyMod()
wrapper = MyWrapper(mod)
wrappers.append(wrapper)
x = torch.tensor([1])
wrapper(x, None) # profile run, compile
# create a cache tensor
cache = torch.tensor([2])
wrapper(x, cache) # warm up with cache, recompile

# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code
assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code

for wrapper in wrappers:
# make sure they have independent compiled codes
assert len(wrapper.compiled_codes) == 2

# First call should trigger compilation
x = torch.tensor([1, 2, 3, 4])
torch._dynamo.mark_dynamic(x, 0)

result1 = wrapper(x)
expected1 = torch.tensor([2, 4, 6, 8])
assert torch.allclose(result1, expected1), (
f"Expected {expected1}, got {result1}"
)

# Second call should trigger another compilation
x2 = torch.tensor([1, 2, 3])
result2 = wrapper(x2)
expected2 = torch.tensor([100, 200, 300])
assert torch.allclose(result2, expected2), (
f"Expected {expected2}, got {result2}"
)

# The NO_COMPILATION mode (mode=None) is not supported.
vllm_config.compilation_config.mode = None
torch._dynamo.reset()
with set_current_vllm_config(vllm_config):
torch._dynamo.reset()
mod = MyMod()

try:
wrapper = MyWrapper(mod)
except Exception:
return
raise AssertionError("expected an exception to be raised")


if __name__ == "__main__":
# Run with both parameter values

class MockMonkeypatch:
def setenv(self, name, value):
os.environ[name] = value

mp = MockMonkeypatch()

print("Testing with VLLM_USE_BYTECODE_HOOK=False")
test_torch_compile_wrapper(False, mp)

print("Testing with VLLM_USE_BYTECODE_HOOK=True")
test_torch_compile_wrapper(True, mp)

print("All tests passed!")
10 changes: 10 additions & 0 deletions tests/models/multimodal/generation/test_qwen2_5_vl.py
@@ -34,6 +34,7 @@ def qwen2_5_vl_chat_template(*query):
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_evs_functionality(
vllm_runner,
video_assets,
@@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality(
num_frames: int,
dtype: str,
max_tokens: int,
use_bytecode_hook: bool,
monkeypatch,
) -> None:
"""Test EVS (Efficient Video Sampling) functionality with different
pruning rates.
"""
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

# Sample frames from video assets
sampled_vids = [
@@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality(
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_evs_batched_videos(
vllm_runner,
video_assets,
@@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos(
num_frames: int,
dtype: str,
max_tokens: int,
use_bytecode_hook: bool,
monkeypatch,
) -> None:
"""Test EVS functionality with batched videos.

@@ -102,6 +110,8 @@
2. Both pruning configurations work with multiple videos
3. The model doesn't crash when processing multiple videos simultaneously
"""
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
# Sample frames from video assets
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
8 changes: 8 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -75,6 +75,14 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


@pytest.fixture(autouse=True)
def reset_torch_dynamo():
"""Reset torch dynamo cache before each test"""
yield
# Cleanup after test
torch._dynamo.reset()


@pytest.mark.parametrize(
"speculative_config",
[