@@ -3,11 +3,13 @@
 
 import pytest
 
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.plugins import set_current_vllm_config
 
 
 # Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_CUSTOM_OPS"] = env
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        custom_ops=env.split(",")))
+    with set_current_vllm_config(vllm_config):
+        assert CustomOp.default_on() == default_on
 
-    # Reset default_on (computed once):
-    CustomOp.default_on.cache_clear()
+        ops_enabled = [bool(x) for x in ops_enabled]
 
-    assert CustomOp.default_on() == default_on
+        assert RMSNorm(1024).enabled() == ops_enabled[0]
+        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
 
-    ops_enabled = [bool(x) for x in ops_enabled]
+        assert SiluAndMul().enabled() == ops_enabled[1]
+        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
 
-    assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+        assert GeluAndMul().enabled() == ops_enabled[2]
+        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
 
-    assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+        # If registered, subclasses should follow their own name
+        assert Relu3().enabled() == ops_enabled[3]
+        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
 
-    assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+        # Unregistered subclass
+        class SiluAndMul2(SiluAndMul):
+            pass
 
-    # If registered, subclasses should follow their own name
-    assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
-
-    # Unregistered subclass
-    class SiluAndMul2(SiluAndMul):
-        pass
-
-    # Subclasses should not require registration
-    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+        # Subclasses should not require registration
+        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
 
 
 @pytest.mark.parametrize(
     "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
 def test_enabled_ops_invalid(env: str):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    CustomOp.default_on.cache_clear()
-
-    with pytest.raises(AssertionError):
-        RMSNorm(1024).enabled()
+    with pytest.raises(Exception):  # noqa
+        vllm_config = VllmConfig(compilation_config=CompilationConfig(
+            custom_ops=env.split(",")))
+        with set_current_vllm_config(vllm_config):
+            RMSNorm(1024).enabled()
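
Note on the pattern this diff adopts (a minimal sketch, not part of the change itself): custom-op toggles move from the VLLM_CUSTOM_OPS environment variable, which forced tests to call CustomOp.default_on.cache_clear() between runs, to CompilationConfig.custom_ops carried on a VllmConfig and scoped with the set_current_vllm_config context manager. Assuming a vLLM checkout at the revision this diff targets, and the "+op"/"-op"/"all"/"none" token semantics the parametrization above encodes, usage looks roughly like:

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config

# Turn every custom op on by default, then opt rms_norm out; these are
# the same tokens the test feeds in via env.split(",").
config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["all", "-rms_norm"]))

with set_current_vllm_config(config):
    assert CustomOp.default_on()        # "all": ops on unless opted out
    assert not RMSNorm(1024).enabled()  # "-rms_norm" disables this op
# On exit the previous config is restored, so no cached state needs to
# be reset by hand between tests.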