@@ -12,10 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-from vllm.envs import VLLM_USE_V1
-from vllm.forward_context import set_forward_context
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -164,104 +163,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def test_ignore_torch_compile_decorator():
-    assert VLLM_USE_V1
-
-    # piecewise
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=True,
-        splitting_ops=["silly.attention"],
-        cudagraph_capture_sizes=[1, 2],
-    ))
-
-    @support_torch_compile
-    class A(nn.Module):
-
-        def __init__(self,
-                     *,
-                     vllm_config: VllmConfig,
-                     prefix: str = '',
-                     **kwargs) -> None:
-            super().__init__()
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            x = x + x
-            attn_output = torch.empty_like(x)
-            torch.ops.silly.attention(x, x, x, attn_output)
-            x = attn_output
-            x = x * 3
-            return x
-
-    @ignore_torch_compile
-    class B(A):
-        ...
-
-    @support_torch_compile
-    class C(B):
-        ...
-
-    with set_current_vllm_config(vllm_config):
-        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # A has support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # B's ignore_torch_compile should override A's support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=0,
-            num_piecewise_graphs_seen=0,
-            num_piecewise_capturable_graphs_seen=0,
-            num_backend_compilations=0,
-            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # C's support_torch_compile should override B's ignore_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
-
-
 @torch.inference_mode
-def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
+def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
+              cudagraph_runtime_mode: CUDAGraphMode):
     with set_forward_context({}, vllm_config=vllm_config):
-        # First run is for compile
+        # warmup for the model with cudagraph_mode NONE
         model(inputs)
 
-        # Run CUDAGraph captured sizes
-        model(inputs[:2])
-        model(inputs[:1])
-
-        output = model(inputs[:2])
+        # simulate cudagraphs capturing
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            model(inputs[:2])
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=1, )):
+            model(inputs[:1])
+
+        # simulate cudagraphs replay
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            output = model(inputs[:2])
 
     output = output.cpu()
     return output.cpu()
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # no compile or cudagraph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, ))
+    cudagraph_runtime_mode = CUDAGraphMode.NONE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -318,14 +250,16 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=0,
             num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # piecewise compile without CUDA graph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=False,
         splitting_ops=["silly.attention"],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=4,
             num_cudagraph_captured=0,  # no cudagraph captured
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # Generally don't expect outputs with and without inductor
     # to be bitwise equivalent
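
Note on the pattern this diff introduces: `run_model` now threads an explicit `CUDAGraphMode` through `set_forward_context`, and each capture/replay pass is paired with a `BatchDescriptor` whose `num_tokens` matches one of the configured `cudagraph_capture_sizes`. A minimal sketch of that calling pattern, using only the imports shown in this diff; `model` and `inputs` are hypothetical stand-ins, not objects from the test file:

# Sketch only: mirrors the capture/replay contexts used by run_model above.
import torch
from torch import nn

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, set_forward_context


def capture_then_replay(vllm_config: VllmConfig, model: nn.Module,
                        inputs: torch.Tensor) -> torch.Tensor:
    mode = CUDAGraphMode.PIECEWISE
    with set_forward_context({}, vllm_config=vllm_config):
        # eager warmup: no runtime mode is set, so nothing is captured
        model(inputs)

        # one capture pass per configured size (here 2, then 1)
        for num_tokens in (2, 1):
            with set_forward_context(
                    {},
                    vllm_config=vllm_config,
                    cudagraph_runtime_mode=mode,
                    batch_descriptor=BatchDescriptor(num_tokens=num_tokens)):
                model(inputs[:num_tokens])

        # replay: the same mode + descriptor selects the captured graph
        with set_forward_context(
                {},
                vllm_config=vllm_config,
                cudagraph_runtime_mode=mode,
                batch_descriptor=BatchDescriptor(num_tokens=2)):
            output = model(inputs[:2])
    return output.cpu()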