[torch.compile] CUDAGraph Inductor partition integration #24281
Changes from 22 commits
@@ -326,6 +326,40 @@ def call_module(self, target: torch.fx.node.Target,
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
             ]
             global compilation_start_time
+
+            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and self.compilation_config.use_inductor_graph_partition):
+                # If we're using Inductor-based graph partitioning, we currently
+                # have the whole `fx.Graph` before Inductor lowering, and
+                # the piecewise splitting happens after all graph
+                # passes and fusions. Here, we add a custom hook for Inductor
+                # to wrap each partition with our static graph wrapper class to
+                # maintain more control over static graph capture and replay.
+
+                from torch._inductor.utils import CUDAGraphWrapperMetadata
+
+                from .cuda_graph import CUDAGraphOptions
+
+                static_graph_wrapper_class = resolve_obj_by_qualname(
+                    current_platform.get_static_graph_wrapper_cls())
+
+                def customized_cudagraph_wrapper(
+                        f, metadata: CUDAGraphWrapperMetadata):
+                    partition_id = metadata.partition_index
+                    num_partitions = metadata.num_partitions
+                    return static_graph_wrapper_class(
+                        runnable=f,
+                        vllm_config=self.vllm_config,
+                        runtime_mode=CUDAGraphMode.PIECEWISE,
+                        cudagraph_options=CUDAGraphOptions(
+                            debug_log_enable=partition_id == 0,
+                            gc_disable=partition_id != 0,
+                            weak_ref_output=partition_id == num_partitions - 1,
+                        ))
+
+                torch._inductor.utils.set_customized_partition_wrappers(
+                    customized_cudagraph_wrapper)
+
             compiled_graph_for_dynamic_shape = self.vllm_backend.\
                 compiler_manager.compile(
                     submod,
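The three `CUDAGraphOptions` flags in the hook above are keyed off the partition index so that per-graph work (debug logging, GC handling, weak-referencing the output) happens exactly once across a partitioned graph. The snippet below is a small illustrative sketch, not code from this PR: it just evaluates the same three conditions for a hypothetical graph split into three partitions.

```python
# Illustrative sketch only: mirrors the conditions used in
# customized_cudagraph_wrapper above for a hypothetical 3-partition graph.
num_partitions = 3

for partition_id in range(num_partitions):
    options = {
        # debug logging is emitted only for the first partition
        "debug_log_enable": partition_id == 0,
        # GC around capture is skipped for all but the first partition
        "gc_disable": partition_id != 0,
        # only the last partition's output is held by weak reference
        "weak_ref_output": partition_id == num_partitions - 1,
    }
    print(partition_id, options)
```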
@@ -336,15 +370,20 @@ def call_module(self, target: torch.fx.node.Target,
                     num_graphs=len(self.compile_submod_names),
                     runtime_shape=None)
             # Lazy import here to avoid circular import
-            from .cuda_graph import CUDAGraphOptions
             from .cuda_piecewise_backend import PiecewiseBackend
 
             piecewise_backend = PiecewiseBackend(
                 submod, self.vllm_config, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and
+                    not self.compilation_config.use_inductor_graph_partition):
+                # We're using Dynamo-based piecewise splitting, so we wrap
+                # the whole subgraph with a static graph wrapper.
+                from .cuda_graph import CUDAGraphOptions
+
                 # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
                 # class) as platform dependent.
                 static_graph_wrapper_class = resolve_obj_by_qualname(
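Taken together, the two hunks split responsibility for static graph wrapping between the existing Dynamo-based piecewise path and the new Inductor partition hook. The sketch below is my own summary of that dispatch, not code from the PR; the `CUDAGraphMode` enum is a local stand-in defined only for illustration.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    # Local stand-in for vLLM's CUDAGraphMode, for illustration only.
    NONE = 0
    PIECEWISE = 1


def wrapping_strategy(cudagraph_mode: CUDAGraphMode,
                      use_inductor_graph_partition: bool) -> str:
    """Summarize which component wraps static (CUDA) graphs after this change."""
    if cudagraph_mode == CUDAGraphMode.NONE:
        return "no static graph wrapping"
    if use_inductor_graph_partition:
        # New path: Inductor partitions the lowered graph and invokes the hook
        # registered via torch._inductor.utils.set_customized_partition_wrappers
        # once per partition.
        return "wrap each Inductor graph partition"
    # Existing path: Dynamo-based piecewise splitting; the whole submodule is
    # wrapped with the platform's static graph wrapper (e.g. CUDAGraphWrapper).
    return "wrap each Dynamo piecewise submodule"


print(wrapping_strategy(CUDAGraphMode.PIECEWISE, True))
print(wrapping_strategy(CUDAGraphMode.PIECEWISE, False))
```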