from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
-                         VllmConfig, set_current_vllm_config)
-from vllm.forward_context import set_forward_context
+                         CUDAGraphMode, VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
@@ -40,6 +40,39 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
)


+@torch.inference_mode
+def run_model(vllm_config: VllmConfig, model: nn.Module,
+              cudagraph_runtime_mode: CUDAGraphMode):
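+    # Drive the model through the warmup -> capture -> replay lifecycle
+    # and return the replayed output on CPU.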
+    with set_forward_context({}, vllm_config=vllm_config):
+        # warm up with cudagraph_mode NONE; the first run triggers compilation
+        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
+
+    # simulate cudagraph capture, one run per configured capture size
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(num_tokens=2)):
+        model(torch.randn(2, MLP_SIZE).cuda())
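+    # capture the second size from cudagraph_capture_sizes ([1, 2])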
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(num_tokens=1)):
+        model(torch.randn(1, MLP_SIZE).cuda())
+
+    # simulate cudagraph replay: rerunning a captured size should hit the graph
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(num_tokens=2)):
+        output = model(torch.randn(2, MLP_SIZE).cuda())
+
+    return output.cpu()
+
+
def test_ignore_torch_compile_decorator():
    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -48,6 +81,7 @@ def test_ignore_torch_compile_decorator():
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
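+    # PIECEWISE mode replays cudagraphs for the compiled pieces between
+    # splitting ops; the splitting ops (here silly.attention) run eagerly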
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    @support_torch_compile
    class A(nn.Module):
@@ -86,12 +120,8 @@ class C(B):
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -103,10 +133,8 @@ class C(B):
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_B, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -119,10 +147,8 @@ class C(B):
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_C, cudagraph_runtime_mode)


# Only enable torch.compile if
@@ -180,6 +206,7 @@ def test_conditional_compile_enable_if():
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -195,12 +222,8 @@ def test_conditional_compile_enable_if():
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    # Set kv_sharing_fast_prefill=False,
    # which will cause A to be compiled and B to not be compiled
@@ -224,9 +247,5 @@ def test_conditional_compile_enable_if():
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)