 import torch.fx as fx
 
 import vllm.envs as envs
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
@@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str],
-                 compilation_configs: CompilationConfig, graph_pool):
+                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 graph_pool):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
-        self.compilation_configs = compilation_configs
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
+        self.vllm_config = vllm_config
 
     def run(self, *args):
         fake_args = [
@@ -182,15 +183,15 @@ def call_module(self, target: torch.fx.node.Target,
             compiled_graph_for_general_shape = wrap_inductor(
                 submod,
                 args,
-                self.compilation_configs.inductor_compile_config,
-                self.compilation_configs,
+                self.compilation_config.inductor_compile_config,
+                self.compilation_config,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
-                use_inductor=self.compilation_configs.use_inductor)
+                use_inductor=self.compilation_config.use_inductor)
 
             self.module.__dict__[target] = PiecewiseBackend(
-                submod, self.compilation_configs, self.graph_pool, index,
+                submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape)
 
@@ -211,7 +212,8 @@ class VllmBackend:
     which handles the post-grad passes.
     """
 
-    compilation_configs: CompilationConfig
+    vllm_config: VllmConfig
+    compilation_config: CompilationConfig
     graph_pool: Any
     _called: bool = False
     # the graph we compiled
@@ -227,7 +229,7 @@ class VllmBackend:
 
     def __init__(
         self,
-        compilation_configs: CompilationConfig,
+        vllm_config: VllmConfig,
     ):
         global global_graph_pool
         if global_graph_pool is None:
@@ -244,13 +246,14 @@ def __init__(
         self.sym_tensor_indices = []
         self.input_buffers = []
 
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
     def configure_post_pass(self):
-        config = self.compilation_configs
+        config = self.compilation_config
         self.post_grad_pass_manager.configure(config.pass_config)
 
         # Post-grad custom passes are run using the post_grad_custom_post_pass
@@ -271,7 +274,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         from .monitor import torch_compile_start_time
         dynamo_time = time.time() - torch_compile_start_time
         logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
-        self.compilation_configs.compilation_time += dynamo_time
+        self.compilation_config.compilation_time += dynamo_time
 
         # we control the compilation process, each instance can only be
         # called once
@@ -281,7 +284,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.splitting_ops)
+            graph, self.compilation_config.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
@@ -298,13 +301,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.compilation_configs,
+                                    self.vllm_config,
                                     self.graph_pool).run(*example_inputs)
 
         self._called = True
 
-        if not self.compilation_configs.use_cudagraph or \
-            not self.compilation_configs.cudagraph_copy_inputs:
+        if not self.compilation_config.use_cudagraph or \
+            not self.compilation_config.cudagraph_copy_inputs:
             return self.split_gm
 
         # if we need to copy input buffers for cudagraph
@@ -364,26 +367,26 @@ class ConcreteSizeEntry:
 
 class PiecewiseBackend:
 
-    def __init__(self, graph: fx.GraphModule,
-                 compilation_configs: CompilationConfig, graph_pool: Any,
-                 piecewise_compile_index: int, total_piecewise_compiles: int,
-                 sym_shape_indices: List[int],
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, piecewise_compile_index: int,
+                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                  compiled_graph_for_general_shape: Callable):
         """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
-        `compilation_configs.compile_sizes`.
+        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
         self.graph = graph
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
@@ -393,10 +396,10 @@ def __init__(self, graph: fx.GraphModule,
             piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.compile_sizes: Set[int] = set(
-            self.compilation_configs.compile_sizes)
+            self.compilation_config.compile_sizes)
         self.capture_sizes: Set[int] = set(
-            self.compilation_configs.capture_sizes
-        ) if self.compilation_configs.use_cudagraph else set()
+            self.compilation_config.capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
 
         self.first_run_finished = False
 
@@ -423,7 +426,7 @@ def __call__(self, *args) -> Any:
             self.first_run_finished = True
             # no specific sizes to compile
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -443,28 +446,28 @@ def __call__(self, *args) -> Any:
             entry.runnable = wrap_inductor(
                 self.graph,
                 args,
-                self.compilation_configs.inductor_compile_config,
-                self.compilation_configs,
+                self.compilation_config.inductor_compile_config,
+                self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
-                use_inductor=self.compilation_configs.use_inductor)
+                use_inductor=self.compilation_config.use_inductor)
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
                         "Warming up %s/%s for shape %s",
                         entry.num_finished_warmup,
-                        self.compilation_configs.cudagraph_num_of_warmups,
+                        self.compilation_config.cudagraph_num_of_warmups,
                         runtime_shape)
                 return entry.runnable(*args)
 
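In short, the diff threads the full `VllmConfig` through the compilation stack instead of a bare `CompilationConfig`: each backend stores `self.vllm_config`, derives `self.compilation_config = vllm_config.compilation_config` from it, and `end_monitoring_torch_compile` now receives the whole `VllmConfig`. A minimal sketch of the resulting calling convention follows; it assumes `VllmBackend` is importable from `vllm.compilation.backends`, that `VllmConfig` can be constructed with only a `compilation_config`, and that a CUDA device is available (none of which is shown in this diff).

```python
# Hedged sketch of the new calling convention after this change; not part of
# the diff itself. Assumptions: VllmBackend lives in vllm.compilation.backends,
# VllmConfig's remaining fields have usable defaults, and a CUDA device is
# present (the backend allocates a CUDA graph memory pool on construction).
from vllm.config import CompilationConfig, VllmConfig
from vllm.compilation.backends import VllmBackend  # assumed module path

# Build the engine-wide config; the compilation options now live inside it.
compilation_config = CompilationConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)

# Before this change the backend was constructed from the CompilationConfig
# directly; now it takes the whole VllmConfig and pulls the compilation
# options out internally.
backend = VllmBackend(vllm_config)

# The backend keeps both handles, so downstream code (e.g. PiecewiseBackend)
# can reach either the compilation options or the rest of the configuration.
assert backend.compilation_config is vllm_config.compilation_config
```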