Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.utils import direct_register_custom_op

global_counter = 0
Expand Down
4 changes: 1 addition & 3 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op

Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
Expand Down
4 changes: 3 additions & 1 deletion tests/compile/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel


class MyMod(torch.nn.Module):
Expand All @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.platforms import current_platform

TEST_MODELS = [
Expand Down
52 changes: 26 additions & 26 deletions tests/model_executor/test_enabled_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import pytest

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config


# Registered subclass for test
Expand Down Expand Up @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
ops_enabled = [bool(x) for x in ops_enabled]

assert CustomOp.default_on() == default_on
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

ops_enabled = [bool(x) for x in ops_enabled]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()

with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
with pytest.raises(Exception): # noqa
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()
2 changes: 1 addition & 1 deletion tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import depyf

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

# disable custom dispatcher, let Dynamo takes over
# all the control
Expand Down
2 changes: 1 addition & 1 deletion tests/tpu/test_custom_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

from ..utils import compare_two_settings

Expand Down
3 changes: 1 addition & 2 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
import torch.fx as fx

import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors

from .config import CompilationConfig
from .counter import compilation_counter
from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)
Expand Down
159 changes: 0 additions & 159 deletions vllm/compilation/config.py

This file was deleted.

10 changes: 5 additions & 5 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
Expand Down Expand Up @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
if self.do_not_compile:
return
TorchCompileWrapperWithCustomDispatcher.__init__(self)
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)

cls.__init__ = __init__ # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import CompilationConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/inductor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from vllm.compilation.config import CompilationConfig
from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
Expand Down
Loading