Skip to content

Commit 32d222a

Browse files
committed
[compile] Turn on TP/SP when use_inductor_graph_partition=True
Signed-off-by: angelayi <[email protected]>
1 parent f178780 commit 32d222a

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -435,7 +435,8 @@ def is_applicable(self, shape: int | None) -> bool:
435435
# This pass is applied on top of the sequence parallelism pass.
436436
# It inherits the same applicability condition as `SequenceParallelismPass`.
437437
# See `SequenceParallelismPass.is_applicable` for more details.
438-
if self.splitting_ops is None or self.splitting_ops == []:
438+
splitting_ops = self.compilation_config.splitting_ops
439+
if not splitting_ops or self.compilation_config.use_inductor_graph_partition:
439440
return True
440441
tp_size = get_tensor_model_parallel_world_size()
441442
return shape is not None and shape % tp_size == 0

vllm/compilation/pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,8 @@ def __call__(self, graph: fx.Graph):
7474
if pass_.is_applicable(shape):
7575
pass_(graph)
7676
VllmInductorPass.dump_prefix += 1
77+
else:
78+
logger.debug(f"Skipping {pass_} with shape {shape}")
7779

7880
# post-cleanup goes before fix_functionalization
7981
# because it requires a functional graph

vllm/compilation/sequence_parallelism.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -496,7 +496,8 @@ def is_applicable(self, shape: int | None) -> bool:
496496
# 2. For specific shape provided during compilation (e.g., from
497497
# `compile_sizes`), which must be divisible by the tensor-parallel
498498
# size.
499-
if self.splitting_ops is None or self.splitting_ops == []:
499+
splitting_ops = self.compilation_config.splitting_ops
500+
if not splitting_ops or self.compilation_config.use_inductor_graph_partition:
500501
return True
501502
tp_size = get_tensor_model_parallel_world_size()
502503
return shape is not None and shape % tp_size == 0

vllm/compilation/vllm_inductor_pass.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,10 @@ class VllmInductorPass(InductorPass):
2828
"""Keep track of pass index for debug dump ordering."""
2929

3030
def __init__(self, config: VllmConfig):
31+
self.compilation_config = config.compilation_config
3132
self.pass_config = config.compilation_config.pass_config
32-
self.splitting_ops = config.compilation_config.splitting_ops
33-
self.model_dtype = config.model_config.dtype if config.model_config \
34-
else None
35-
self.device = config.device_config.device if config.device_config \
36-
else None
33+
self.model_dtype = config.model_config.dtype if config.model_config else None
34+
self.device = config.device_config.device if config.device_config else None
3735
self.pass_name = self.__class__.__name__
3836

3937
@staticmethod

0 commit comments

Comments (0)