
Commit 782521e

youkaichao authored and weilong.yu committed
[torch.compile] use depyf to dump torch.compile internals (vllm-project#10972)
Signed-off-by: youkaichao <[email protected]>
1 parent e30096b commit 782521e

7 files changed: +66 -42 lines changed

requirements-common.txt

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
 einops # Required for Qwen2-VL.
 compressed-tensors == 0.8.0 # required for compressed-tensors
+depyf==0.18.0 # required for profiling and debugging torch.compile
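
The new dependency, depyf, decompiles the bytecode that Dynamo generates during torch.compile back into readable Python source. A minimal standalone sketch of how it is typically used, independent of vLLM (the dump directory name is arbitrary):

import torch
import depyf

def toy_fn(x):
    return torch.relu(x) + 1

compiled_fn = torch.compile(toy_fn)

# Inside this context, depyf decompiles Dynamo's transformed bytecode and
# writes the recovered Python source (plus compiler artifacts) under the
# given directory.
with depyf.prepare_debug("./depyf_dump"):
    compiled_fn(torch.randn(4))

The commit wires exactly this prepare_debug context manager into vLLM's compilation monitor (see vllm/compilation/monitor.py below).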

vllm/compilation/backends.py

Lines changed: 36 additions & 33 deletions
@@ -9,7 +9,7 @@
 import torch.fx as fx
 
 import vllm.envs as envs
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
@@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str],
-                 compilation_configs: CompilationConfig, graph_pool):
+                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 graph_pool):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
-        self.compilation_configs = compilation_configs
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
+        self.vllm_config = vllm_config
 
     def run(self, *args):
         fake_args = [
@@ -182,15 +183,15 @@ def call_module(self, target: torch.fx.node.Target,
         compiled_graph_for_general_shape = wrap_inductor(
             submod,
             args,
-            self.compilation_configs.inductor_compile_config,
-            self.compilation_configs,
+            self.compilation_config.inductor_compile_config,
+            self.compilation_config,
             graph_index=index,
             num_graphs=len(self.compile_submod_names),
             runtime_shape=None,
-            use_inductor=self.compilation_configs.use_inductor)
+            use_inductor=self.compilation_config.use_inductor)
 
         self.module.__dict__[target] = PiecewiseBackend(
-            submod, self.compilation_configs, self.graph_pool, index,
+            submod, self.vllm_config, self.graph_pool, index,
             len(self.compile_submod_names), sym_shape_indices,
             compiled_graph_for_general_shape)
 
@@ -211,7 +212,8 @@ class VllmBackend:
     which handles the post-grad passes.
     """
 
-    compilation_configs: CompilationConfig
+    vllm_config: VllmConfig
+    compilation_config: CompilationConfig
     graph_pool: Any
     _called: bool = False
     # the graph we compiled
@@ -227,7 +229,7 @@ class VllmBackend:
 
     def __init__(
         self,
-        compilation_configs: CompilationConfig,
+        vllm_config: VllmConfig,
     ):
         global global_graph_pool
         if global_graph_pool is None:
@@ -244,13 +246,14 @@ def __init__(
         self.sym_tensor_indices = []
         self.input_buffers = []
 
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
 
     def configure_post_pass(self):
-        config = self.compilation_configs
+        config = self.compilation_config
         self.post_grad_pass_manager.configure(config.pass_config)
 
         # Post-grad custom passes are run using the post_grad_custom_post_pass
@@ -271,7 +274,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         from .monitor import torch_compile_start_time
         dynamo_time = time.time() - torch_compile_start_time
         logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
-        self.compilation_configs.compilation_time += dynamo_time
+        self.compilation_config.compilation_time += dynamo_time
 
         # we control the compilation process, each instance can only be
         # called once
@@ -281,7 +284,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.splitting_ops)
+            graph, self.compilation_config.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
@@ -298,13 +301,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.compilation_configs,
+                                    self.vllm_config,
                                     self.graph_pool).run(*example_inputs)
 
         self._called = True
 
-        if not self.compilation_configs.use_cudagraph or \
-            not self.compilation_configs.cudagraph_copy_inputs:
+        if not self.compilation_config.use_cudagraph or \
+            not self.compilation_config.cudagraph_copy_inputs:
             return self.split_gm
 
         # if we need to copy input buffers for cudagraph
@@ -364,26 +367,26 @@ class ConcreteSizeEntry:
 
 class PiecewiseBackend:
 
-    def __init__(self, graph: fx.GraphModule,
-                 compilation_configs: CompilationConfig, graph_pool: Any,
-                 piecewise_compile_index: int, total_piecewise_compiles: int,
-                 sym_shape_indices: List[int],
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, piecewise_compile_index: int,
+                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                  compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
 
         We will compile `self.graph` once for the general shape,
         and then compile for different shapes specified in
-        `compilation_configs.compile_sizes`.
+        `compilation_config.compile_sizes`.
 
         Independently, we will capture cudagraph for different shapes.
 
         If a shape needs both compilation and cudagraph, we will
         compile it first, and then capture cudagraph.
         """
         self.graph = graph
-        self.compilation_configs = compilation_configs
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
@@ -393,10 +396,10 @@ def __init__(self, graph: fx.GraphModule,
             piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.compile_sizes: Set[int] = set(
-            self.compilation_configs.compile_sizes)
+            self.compilation_config.compile_sizes)
         self.capture_sizes: Set[int] = set(
-            self.compilation_configs.capture_sizes
-        ) if self.compilation_configs.use_cudagraph else set()
+            self.compilation_config.capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
 
         self.first_run_finished = False
 
@@ -423,7 +426,7 @@ def __call__(self, *args) -> Any:
             self.first_run_finished = True
             # no specific sizes to compile
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -443,28 +446,28 @@ def __call__(self, *args) -> Any:
             entry.runnable = wrap_inductor(
                 self.graph,
                 args,
-                self.compilation_configs.inductor_compile_config,
-                self.compilation_configs,
+                self.compilation_config.inductor_compile_config,
+                self.compilation_config,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
-                use_inductor=self.compilation_configs.use_inductor)
+                use_inductor=self.compilation_config.use_inductor)
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.compilation_configs)
+                end_monitoring_torch_compile(self.vllm_config)
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
                         "Warming up %s/%s for shape %s",
                         entry.num_finished_warmup,
-                        self.compilation_configs.cudagraph_num_of_warmups,
+                        self.compilation_config.cudagraph_num_of_warmups,
                         runtime_shape)
                 return entry.runnable(*args)
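
The bulk of this file is a mechanical rename: each class now receives the whole VllmConfig and keeps a compilation_config shortcut, so cross-cutting information such as vllm_config.parallel_config.rank is reachable where the per-rank dump path is built. A minimal runnable sketch of the pattern, with toy dataclasses standing in for the real vllm.config classes:

from dataclasses import dataclass, field

@dataclass
class ParallelConfig:        # stand-in for vllm.config.ParallelConfig
    rank: int = 0

@dataclass
class CompilationConfig:     # stand-in for vllm.config.CompilationConfig
    debug_dump_path: str = ""

@dataclass
class VllmConfig:            # stand-in for vllm.config.VllmConfig
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    compilation_config: CompilationConfig = field(
        default_factory=CompilationConfig)

class Backend:
    def __init__(self, vllm_config: VllmConfig):
        # keep the full config for cross-cutting needs (e.g. rank) ...
        self.vllm_config = vllm_config
        # ... plus a shortcut to the sub-config most call sites use
        self.compilation_config = vllm_config.compilation_config

Passing the top-level config once avoids threading an ever-growing list of sub-configs through every constructor.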

vllm/compilation/decorators.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def __call__(self, *args, **kwargs):
                 "Unsupported dynamic dimensions"
                 f" {dims} for argument {k} with type {type(arg)}.")
         # here, it is the starting point of the `torch.compile` process
-        start_monitoring_torch_compile(self.vllm_config.compilation_config)
+        start_monitoring_torch_compile(self.vllm_config)
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,

vllm/compilation/monitor.py

Lines changed: 20 additions & 3 deletions
@@ -1,19 +1,36 @@
+import os
 import time
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+context_manager = None
 torch_compile_start_time: float = 0.0
 
 
-def start_monitoring_torch_compile(compilation_config: CompilationConfig):
+def start_monitoring_torch_compile(vllm_config: VllmConfig):
     global torch_compile_start_time
     torch_compile_start_time = time.time()
 
+    compilation_config: CompilationConfig = vllm_config.compilation_config
+    if compilation_config.level == CompilationLevel.PIECEWISE and \
+        compilation_config.debug_dump_path:
+        import depyf
+        path = os.path.join(compilation_config.debug_dump_path,
+                            f"rank_{vllm_config.parallel_config.rank}")
+        global context_manager
+        context_manager = depyf.prepare_debug(path)
+        context_manager.__enter__()
 
-def end_monitoring_torch_compile(compilation_config: CompilationConfig):
+
+def end_monitoring_torch_compile(vllm_config: VllmConfig):
+    compilation_config: CompilationConfig = vllm_config.compilation_config
     if compilation_config.level == CompilationLevel.PIECEWISE:
         logger.info("torch.compile takes %.2f s in total",
                     compilation_config.compilation_time)
+        global context_manager
+        if context_manager is not None:
+            context_manager.__exit__(None, None, None)
+            context_manager = None
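
A hedged usage sketch: with piecewise compilation enabled and the new debug_dump_path field set, each rank's torch.compile internals land under <debug_dump_path>/rank_<rank>/. This assumes your vLLM build accepts compilation_config on the LLM entrypoint (the kwarg and model name below are illustrative; check your version):

from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="facebook/opt-125m",  # any model; illustrative choice
    compilation_config=CompilationConfig(
        level=3,                                   # CompilationLevel.PIECEWISE
        debug_dump_path="/tmp/vllm_compile_dump",  # enables the depyf dump
    ),
)
outputs = llm.generate("Hello, my name is")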

vllm/compilation/wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ def __init__(self,
         # default compilation settings
         # compiling the forward method
 
-        backend = get_current_vllm_config(
-        ).compilation_config.init_backend()
+        vllm_config = get_current_vllm_config()
+        backend = vllm_config.compilation_config.init_backend(vllm_config)
 
         compiled_callable = torch.compile(
             self.forward,

vllm/config.py

Lines changed: 4 additions & 2 deletions
@@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel):
         - 1: dynamo as is.
         - 2: dynamo once.
         - 3: piecewise compilation.
+    - debug_dump_path: the path to dump the debug information.
     - backend: the backend for compilation. It needs to be a string.
         - "" (empty string): use the default backend.
         - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel):
         certain small batchsizes, where inductor is good at optimizing.
     """  # noqa
     level: int = 0
+    debug_dump_path: str = ""
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
@@ -2394,7 +2396,7 @@ def model_post_init(self, __context: Any) -> None:
         self.static_forward_context = {}
         self.compilation_time = 0.0
 
-    def init_backend(self) -> Union[str, Callable]:
+    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
@@ -2413,7 +2415,7 @@ def init_backend(self) -> Union[str, Callable]:
         # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
         from vllm.compilation.backends import VllmBackend
-        return VllmBackend(self)
+        return VllmBackend(vllm_config)
 
     def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
         """To complete the initialization of config,

vllm/worker/model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -1162,7 +1162,8 @@ def load_model(self) -> None:
 
         if self.vllm_config.compilation_config.level ==\
             CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
-            backend = self.vllm_config.compilation_config.init_backend()
+            backend = self.vllm_config.compilation_config.init_backend(
+                self.vllm_config)
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
