Commit eecb574

Revert "[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)"
This reverts commit bd56c98.
1 parent ca2ca8d commit eecb574

File tree

9 files changed: +170 additions, −249 deletions

tests/compile/backend.py

Lines changed: 4 additions & 9 deletions
@@ -13,26 +13,21 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.
-
-    Inductor config can be modified directly by editing the inductor_config
-    property. This can be helpful for adding passes like the
-    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """

     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.inductor_config = config.shallow_copy_dict()
-        self.inductor_config['force_disable_caches'] = True
-        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
+        self.current_config = config.shallow_copy_dict()
+        self.current_config['force_disable_caches'] = True
+        self.current_config['post_grad_custom_post_pass'] = self.post_pass

     def __call__(self, graph: fx.GraphModule, example_inputs):
-        self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.inductor_config)
+                          config_patches=self.current_config)

     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
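
For reference, a minimal sketch of how TestBackend is driven, mirroring the tests below; the placeholder model and the top-level import path are assumptions for illustration only:

    import torch

    from vllm.compilation.reshapes import RedundantReshapesPass
    from vllm.config import CompilationConfig

    from tests.compile.backend import TestBackend  # assumed import path for the file above

    # Any custom FX passes go into the backend; it patches them into Inductor as
    # post_grad_custom_post_pass via the copied current_config.
    config = CompilationConfig.PassConfig(enable_fusion=False, enable_reshape=True)
    backend = TestBackend(RedundantReshapesPass(config))

    model = torch.nn.Linear(16, 16, device="cuda")  # placeholder model, not from the diff
    compiled = torch.compile(model, backend=backend)
    compiled(torch.rand(4, 16, device="cuda"))

    # The FX graph before and after the custom passes is kept for inspection:
    # backend.graph_pre_pass / backend.graph_post_pass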

tests/compile/test_functionalization.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
-from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.reshapes import RedundantReshapesPass
 from vllm.config import CompilationConfig

 from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     torch.set_default_device("cuda")

     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_noop=True)
-    noop_pass = NoOpEliminationPass(config)
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
     fusion_pass = FusionPass.instance(config)

-    passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
+    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
     func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)

tests/compile/test_fusion.py

Lines changed: 58 additions & 69 deletions
@@ -5,25 +5,23 @@
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
-import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.noop_elimination import NoOpEliminationPass
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
+    apply_fp8_linear)

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cutlass_fp8_enabled: bool, *args, **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
-        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
@@ -43,17 +41,15 @@ def forward(self, x):
                               self.w[0],
                               self.wscale[0],
                               self.scale[0],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+                              use_per_token_if_dynamic=True)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

         x3 = apply_fp8_linear(y2,
                               self.w[1],
                               self.wscale[1],
                               self.scale[1],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+                              use_per_token_if_dynamic=True)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

@@ -63,67 +59,60 @@ def forward(self, x):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              cutlass_fp8_enabled):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
-    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
-    with vllm.config.set_current_vllm_config(vllm_config):
-        # Reshape pass is needed for the fusion pass to work
-        config = CompilationConfig.PassConfig(enable_fusion=True,
-                                              enable_noop=True)
-        noop_pass = NoOpEliminationPass(config)
-        fusion_pass = FusionPass.instance(config)
-
-        backend = TestBackend(noop_pass, fusion_pass)
-        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
-
-        # First dimension dynamic
-        x = torch.rand(num_tokens, hidden_size)
-        torch._dynamo.mark_dynamic(x, 0)
-
-        result = model(x)
-
-        model2 = torch.compile(model, backend=backend)
-        result2 = model2(x)
-
-        # Higher tol for dynamic, even higher for bfloat16
-        if static:
-            ATOL, RTOL = (1e-3, 1e-3)
-        elif dtype == torch.float16:
-            ATOL, RTOL = (2e-3, 2e-3)
-        else:
-            ATOL, RTOL = (1e-2, 1e-2)
-
-        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
-
-        # Check substitution worked
-        pre_nodes = backend.graph_pre_pass.nodes
-        post_nodes = backend.graph_post_pass.nodes
-
-        # static is per-tensor, dynamic is per-token
-        key = QuantKey(dtype=FP8_DTYPE,
-                       static=static,
-                       per_tensor=static,
-                       symmetric=True)
-        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-        fp8_quant = QUANT_OPS[key]
-
-        # In pre-nodes, fp8 quant should be there and fused kernels should not
-        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-        find_auto_fn(pre_nodes, fp8_quant)
-
-        # In post-nodes, fused kernels should be there and fp8 quant should not
-        find_auto_fn(post_nodes, rms_quant)
-        find_auto_fn(post_nodes, add_rms_quant)
-        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    # Reshape pass is needed for the fusion pass to work
+    config = CompilationConfig.PassConfig(enable_fusion=True,
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
+    backend = TestBackend(reshape_pass, fusion_pass)
+    model = TestModel(hidden_size, eps, static)
+
+    # First dimension dynamic
+    x = torch.rand(num_tokens, hidden_size)
+    torch._dynamo.mark_dynamic(x, 0)
+
+    result = model(x)
+
+    model2 = torch.compile(model, backend=backend)
+    result2 = model2(x)
+
+    # Higher tol for dynamic, even higher for bfloat16
+    if static:
+        ATOL, RTOL = (1e-3, 1e-3)
+    elif dtype == torch.float16:
+        ATOL, RTOL = (2e-3, 2e-3)
+    else:
+        ATOL, RTOL = (1e-2, 1e-2)
+
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+    # Check substitution worked
+    pre_nodes = backend.graph_pre_pass.nodes
+    post_nodes = backend.graph_post_pass.nodes
+
+    # static is per-tensor, dynamic is per-token
+    key = QuantKey(dtype=FP8_DTYPE,
+                   static=static,
+                   per_tensor=static,
+                   symmetric=True)
+    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+    fp8_quant = QUANT_OPS[key]
+
+    # In pre-nodes, fp8 quant should be present and fused kernels should not
+    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+    find_auto_fn(pre_nodes, fp8_quant)
+
+    # In post-nodes, fused kernels should be present and fp8 quant should not
+    find_auto_fn(post_nodes, rms_quant)
+    find_auto_fn(post_nodes, add_rms_quant)
+    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
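
As a side note on the key construction above, a small sketch of how the test resolves the expected ops; all names come from the imports in this diff, and the second argument to FusedRMSQuantKey is read here as selecting the fused residual-add variant, matching the rms_quant/add_rms_quant usage in the test:

    from compressed_tensors.quantization import FP8_DTYPE

    from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                         QuantKey)

    # Static quantization is per-tensor, dynamic is per-token, so the test
    # reuses the `static` flag for `per_tensor`.
    static = True
    key = QuantKey(dtype=FP8_DTYPE, static=static, per_tensor=static,
                   symmetric=True)

    fp8_quant = QUANT_OPS[key]                              # unfused quant op (pre-fusion graph)
    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]     # rms_norm + quant fused op
    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]  # residual-add variant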

vllm/compilation/noop_elimination.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

vllm/compilation/pass_manager.py

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,7 @@
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
-from .noop_elimination import NoOpEliminationPass
+from .reshapes import RedundantReshapesPass

 logger = init_logger(__name__)

@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):

     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
-    2. default passes (NoopEliminationPass, FusionPass)
+    2. default passes (RedundantReshapesPass, FusionPass)
     3. config["post_grad_custom_post_pass"] (if it exists)
     4. fix_functionalization
     This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ def __call__(self, graph: fx.Graph):

     def configure(self, pass_config: CompilationConfig.PassConfig):
         self.pass_config = pass_config
-        if pass_config.enable_noop:
-            self.passes += [NoOpEliminationPass(pass_config)]
+        if pass_config.enable_reshape:
+            self.passes += [RedundantReshapesPass(pass_config)]

         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]
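
For orientation, a minimal sketch of how the restored flags map onto the default passes listed in the docstring above; the zero-argument constructor is an assumption (per the docstring, extra custom passes can also be supplied when the manager is built):

    from vllm.compilation.pass_manager import PostGradPassManager
    from vllm.config import CompilationConfig

    pass_config = CompilationConfig.PassConfig(enable_reshape=True,
                                               enable_fusion=True)

    pass_manager = PostGradPassManager()  # assumption: no extra custom passes supplied
    pass_manager.configure(pass_config)   # appends RedundantReshapesPass, then FusionPass

    # The manager is itself callable on an fx.Graph (see __call__ above), so it
    # can be installed as Inductor's post_grad_custom_post_pass.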
