@@ -5,25 +5,23 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
-import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.noop_elimination import NoOpEliminationPass
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
+    apply_fp8_linear)
 
 from .backend import TestBackend
 
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cutlass_fp8_enabled: bool, *args, **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
-        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
@@ -43,17 +41,15 @@ def forward(self, x):
                               self.w[0],
                               self.wscale[0],
                               self.scale[0],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+                              use_per_token_if_dynamic=True)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
         x3 = apply_fp8_linear(y2,
                               self.w[1],
                               self.wscale[1],
                               self.scale[1],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+                              use_per_token_if_dynamic=True)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
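For context on the `resid` threading in the hunk above: vLLM's `RMSNorm` supports a fused add-and-normalize form; when called with a residual it returns both the normalized output and the updated residual, which is what lets the fusion pass match the fused-add variant of the pattern (the `FusedRMSQuantKey(key, True)` lookup in the test below). A minimal sketch of the two calling conventions, assuming a CUDA device as in the test:

    import torch
    from vllm.model_executor.layers.layernorm import RMSNorm

    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    norm = RMSNorm(16, eps=1e-5)
    x, resid = torch.rand(4, 16), torch.rand(4, 16)

    y = norm(x)                 # plain rms_norm: returns one tensor
    y2, resid = norm(x, resid)  # add + rms_norm: returns (output, residual)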
@@ -63,67 +59,60 @@ def forward(self, x):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              cutlass_fp8_enabled):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
-    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
 
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
-    with vllm.config.set_current_vllm_config(vllm_config):
-        # Reshape pass is needed for the fusion pass to work
-        config = CompilationConfig.PassConfig(enable_fusion=True,
-                                              enable_noop=True)
-        noop_pass = NoOpEliminationPass(config)
-        fusion_pass = FusionPass.instance(config)
-
-        backend = TestBackend(noop_pass, fusion_pass)
-        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
-
-        # First dimension dynamic
-        x = torch.rand(num_tokens, hidden_size)
-        torch._dynamo.mark_dynamic(x, 0)
-
-        result = model(x)
-
-        model2 = torch.compile(model, backend=backend)
-        result2 = model2(x)
-
-        # Higher tol for dynamic, even higher for bfloat16
-        if static:
-            ATOL, RTOL = (1e-3, 1e-3)
-        elif dtype == torch.float16:
-            ATOL, RTOL = (2e-3, 2e-3)
-        else:
-            ATOL, RTOL = (1e-2, 1e-2)
-
-        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
-
-        # Check substitution worked
-        pre_nodes = backend.graph_pre_pass.nodes
-        post_nodes = backend.graph_post_pass.nodes
-
-        # static is per-tensor, dynamic is per-token
-        key = QuantKey(dtype=FP8_DTYPE,
-                       static=static,
-                       per_tensor=static,
-                       symmetric=True)
-        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-        fp8_quant = QUANT_OPS[key]
-
-        # In pre-nodes, fp8 quant should be there and fused kernels should not
-        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-        find_auto_fn(pre_nodes, fp8_quant)
-
-        # In post-nodes, fused kernels should be there and fp8 quant should not
-        find_auto_fn(post_nodes, rms_quant)
-        find_auto_fn(post_nodes, add_rms_quant)
-        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    # Reshape pass is needed for the fusion pass to work
+    config = CompilationConfig.PassConfig(enable_fusion=True,
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
+    backend = TestBackend(reshape_pass, fusion_pass)
+    model = TestModel(hidden_size, eps, static)
+
+    # First dimension dynamic
+    x = torch.rand(num_tokens, hidden_size)
+    torch._dynamo.mark_dynamic(x, 0)
+
+    result = model(x)
+
+    model2 = torch.compile(model, backend=backend)
+    result2 = model2(x)
+
+    # Higher tol for dynamic, even higher for bfloat16
+    if static:
+        ATOL, RTOL = (1e-3, 1e-3)
+    elif dtype == torch.float16:
+        ATOL, RTOL = (2e-3, 2e-3)
+    else:
+        ATOL, RTOL = (1e-2, 1e-2)
+
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+    # Check substitution worked
+    pre_nodes = backend.graph_pre_pass.nodes
+    post_nodes = backend.graph_post_pass.nodes
+
+    # static is per-tensor, dynamic is per-token
+    key = QuantKey(dtype=FP8_DTYPE,
+                   static=static,
+                   per_tensor=static,
+                   symmetric=True)
+    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+    fp8_quant = QUANT_OPS[key]
+
+    # In pre-nodes, fp8 quant should be present and fused kernels should not
+    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+    find_auto_fn(pre_nodes, fp8_quant)
+
+    # In post-nodes, fused kernels should be present and fp8 quant should not
+    find_auto_fn(post_nodes, rms_quant)
+    find_auto_fn(post_nodes, add_rms_quant)
+    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
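The test marks the token dimension dynamic before compiling, so a single traced graph covers every `num_tokens` parametrization. A standalone sketch of that mechanism in plain PyTorch (no vLLM; the function `f` is illustrative):

    import torch

    def f(x):
        return torch.relu(x) * 2

    x = torch.rand(7, 16)
    torch._dynamo.mark_dynamic(x, 0)  # treat dim 0 (num_tokens) as symbolic
    compiled = torch.compile(f)

    print(compiled(x).shape)                    # traces once, symbolic dim 0
    print(compiled(torch.rand(256, 16)).shape)  # same graph, no recompilation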