Commit 4fd9375

[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <[email protected]>
1 parent: 661a34f
27 files changed: +359 -283 lines
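Every hunk below follows the same migration: CompilationLevel and CompilationConfig move out of vllm.compilation.levels / vllm.compilation.config into vllm.config, and the compilation settings become part of VllmConfig. A minimal sketch of the new surface, assuming only what the diffs below show (the new import paths and the compilation_config keyword):

# Old import paths, deleted by this commit:
#   from vllm.compilation.levels import CompilationLevel
#   from vllm.compilation.config import CompilationConfig
# New home for both, alongside VllmConfig:
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# Compilation settings now travel inside the top-level vLLM config.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["all"]))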

tests/compile/piecewise/test_simple.py
Lines changed: 5 additions & 3 deletions

@@ -11,8 +11,8 @@
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op
 
 global_counter = 0
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
 
-    model = SillyModel(vllm_config=VllmConfig(), prefix='')
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix='')
 
     inputs = torch.randn(100).cuda()
 
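The second hunk introduces the construction pattern that recurs through the rest of this commit: build the VllmConfig first, then instantiate the model inside set_current_vllm_config so that layers created during __init__ can see the active config. A standalone sketch of the pattern; ToyModel here is a placeholder, not a vLLM class:

import torch.nn as nn

from vllm.config import VllmConfig
from vllm.plugins import set_current_vllm_config


class ToyModel(nn.Module):
    # Stand-in for SillyModel / LlamaModel from the tests in this commit.
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.proj = nn.Linear(8, 8)


vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
    # Layers built here can look up the current config.
    model = ToyModel(vllm_config=vllm_config, prefix="")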

tests/compile/piecewise/test_toy_llama.py
Lines changed: 12 additions & 10 deletions

@@ -15,12 +15,10 @@
 from torch.library import Library
 
 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
-from vllm.plugins import set_compilation_config
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.plugins import set_compilation_config, set_current_vllm_config
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -272,9 +270,11 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    model = LlamaModel(config=llama_config,
-                       vllm_config=VllmConfig(),
-                       prefix="").eval().cuda()
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = LlamaModel(config=llama_config,
+                           vllm_config=vllm_config,
+                           prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -395,9 +395,11 @@ def benchmark():
     else:
         set_compilation_config(None)
 
-    model = LlamaModel(config=llama_config,
-                       vllm_config=VllmConfig(),
-                       prefix="").eval().cuda().to(torch.bfloat16)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = LlamaModel(config=llama_config,
+                           vllm_config=vllm_config,
+                           prefix="").eval().cuda().to(torch.bfloat16)
 
     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()

tests/compile/test_basic_correctness.py
Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 import pytest
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.utils import cuda_device_count_stateless
 
 from ..utils import compare_all_settings

tests/compile/test_full_graph.py
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import pytest
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support

tests/compile/test_fusion.py
Lines changed: 1 addition & 1 deletion

@@ -3,10 +3,10 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear)

tests/compile/test_wrapper.py
Lines changed: 3 additions & 1 deletion

@@ -3,6 +3,7 @@
 import torch
 
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import CompilationLevel
 
 
 class MyMod(torch.nn.Module):
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
     def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(compiled_callable)
+        super().__init__(compiled_callable,
+                         compilation_level=CompilationLevel.DYNAMO_ONCE)
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         # this is the function to be compiled
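The wrapper's compilation level is now passed explicitly at construction time; the test passes compilation_level=CompilationLevel.DYNAMO_ONCE where it previously passed nothing. A stripped-down sketch of the same subclassing pattern (whether the argument has a default is not visible in this diff):

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel


class EagerWrapper(TorchCompileWrapperWithCustomDispatcher):
    # Minimal subclass mirroring MyWrapper from the test above.
    def __init__(self):
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(compiled_callable,
                         compilation_level=CompilationLevel.DYNAMO_ONCE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the function to be compiled
        return x + 1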

tests/compile/utils.py
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.platforms import current_platform
 
 TEST_MODELS = [

tests/model_executor/test_enabled_custom_ops.py
Lines changed: 26 additions & 26 deletions

@@ -3,11 +3,13 @@
 
 import pytest
 
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.plugins import set_current_vllm_config
 
 
 # Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_CUSTOM_OPS"] = env
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        custom_ops=env.split(",")))
+    with set_current_vllm_config(vllm_config):
+        assert CustomOp.default_on() == default_on
 
-    # Reset default_on (computed once):
-    CustomOp.default_on.cache_clear()
+        ops_enabled = [bool(x) for x in ops_enabled]
 
-    assert CustomOp.default_on() == default_on
+        assert RMSNorm(1024).enabled() == ops_enabled[0]
+        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
 
-    ops_enabled = [bool(x) for x in ops_enabled]
+        assert SiluAndMul().enabled() == ops_enabled[1]
+        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
 
-    assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+        assert GeluAndMul().enabled() == ops_enabled[2]
+        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
 
-    assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+        # If registered, subclasses should follow their own name
+        assert Relu3().enabled() == ops_enabled[3]
+        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
 
-    assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+        # Unregistered subclass
+        class SiluAndMul2(SiluAndMul):
+            pass
 
-    # If registered, subclasses should follow their own name
-    assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
-
-    # Unregistered subclass
-    class SiluAndMul2(SiluAndMul):
-        pass
-
-    # Subclasses should not require registration
-    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+        # Subclasses should not require registration
+        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
 
 
 @pytest.mark.parametrize(
     "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
 def test_enabled_ops_invalid(env: str):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    CustomOp.default_on.cache_clear()
-
-    with pytest.raises(AssertionError):
-        RMSNorm(1024).enabled()
+    with pytest.raises(Exception):  # noqa
+        vllm_config = VllmConfig(compilation_config=CompilationConfig(
+            custom_ops=env.split(",")))
+        with set_current_vllm_config(vllm_config):
+            RMSNorm(1024).enabled()
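This rewrite shows the other half of the migration: the custom-op list moves from the VLLM_CUSTOM_OPS environment variable (plus the CustomOp.default_on cache) into CompilationConfig.custom_ops, consulted while a config is current. A sketch of the new flow; the token semantics ("all"/"none" defaults with "+op"/"-op" overrides) are implied by the test parameters above rather than spelled out in this diff:

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config

# Assumed semantics: "all" turns custom ops on by default, and
# "-rms_norm" force-disables that one op.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["all", "-rms_norm"]))
with set_current_vllm_config(vllm_config):
    assert not RMSNorm(1024).enabled()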

tests/tpu/test_compilation.py
Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 
 import depyf
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 # disable custom dispatcher, let Dynamo takes over
 # all the control

tests/tpu/test_custom_dispatcher.py
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import os
 
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 
 from ..utils import compare_two_settings
 