11 changes: 9 additions & 2 deletions vllm/compilation/collective_fusion.py
@@ -431,8 +431,15 @@ def __init__(self, config: VllmConfig):
 
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, shape: int | None) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 

Review thread on this hunk:

Collaborator: Fine for now, but @cascade812, isn't this pass technically fine no matter the shape? Obviously it won't match anything if sequence parallelism didn't run, but still.

Contributor: Yes, removing this implementation would work as well, but it would then trigger the matching logic for this pass, which could add some overhead.
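The overhead the contributor mentions can be made concrete with a toy sketch. These are hypothetical classes, not vLLM's actual passes: the point is that a pass's __call__ does graph-proportional pattern-matching work even when nothing matches, so an O(1) is_applicable check skips that traversal outright.

# Toy sketch (hypothetical classes, not vLLM's): gating a pass with
# is_applicable avoids a full graph traversal that would match nothing.
class FusionPass:
    def __init__(self, tp_size: int):
        self.tp_size = tp_size

    def is_applicable(self, shape: int | None) -> bool:
        # Cheap O(1) check, mirroring the divisibility condition above.
        return shape is not None and shape % self.tp_size == 0

    def __call__(self, graph: list) -> None:
        # Stand-in for pattern matching: O(len(graph)) work even on no match.
        for _node in graph:
            pass

graph = list(range(100_000))  # pretend FX graph
fusion = FusionPass(tp_size=4)
for shape in (None, 6, 8):
    if fusion.is_applicable(shape):
        fusion(graph)  # only runs for shape == 8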

2 changes: 1 addition & 1 deletion vllm/compilation/inductor_pass.py
@@ -96,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: int | None):
+    def is_applicable(self, shape: int | None):
         return True
 


4 changes: 3 additions & 1 deletion vllm/compilation/pass_manager.py
@@ -71,9 +71,11 @@ def __call__(self, graph: fx.Graph):
 
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
                 VllmInductorPass.dump_prefix += 1
+            else:
+                logger.debug("Skipping %s with shape %s", pass_, shape)
 
         # post-cleanup goes before fix_functionalization
         # because it requires a functional graph
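A side note on the new debug line: passing pass_ and shape as logging arguments (rather than interpolating into the message) defers string formatting until a record is actually emitted, so the skipped-pass branch stays cheap when debug logging is off. A small illustration, with an assumed logger name:

import logging

logger = logging.getLogger("vllm.compilation")  # assumed name, for illustration

class NoOpPass:
    def __repr__(self) -> str:
        # %s formatting of the pass object calls __repr__/__str__ lazily.
        return "NoOpPass()"

# Formatting happens only if a handler at DEBUG level is active.
logger.debug("Skipping %s with shape %s", NoOpPass(), 8192)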
20 changes: 19 additions & 1 deletion vllm/compilation/sequence_parallelism.py
@@ -482,7 +482,25 @@ def __init__(self, config: VllmConfig):
         ).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: int | None) -> bool:
+    def is_applicable(self, shape: int | None) -> bool:
+        # When sequence parallelism is enabled, the residual tensor from RMSNorm
+        # needs to be split along the sequence dimension. However, this dimension
+        # is symbolic during piecewise compilation, and splitting symbolic shapes
+        # is not supported.
+        #
+        # This pass is therefore only applied when the sequence dimension is
+        # concrete:
+        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
+        #    In this case we always pad num_tokens to be a multiple of
+        #    tensor_parallel_size, so there is no need to check shape % tp_size == 0.
+        # 2. For a specific shape provided during compilation (e.g., from
+        #    `compile_sizes`), which must be divisible by the tensor-parallel
+        #    size.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
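A worked illustration of the two branches above, with tp_size = 4. The function below is a standalone restatement with assumed config values, not the class method itself:

def is_applicable(shape, splitting_ops, use_inductor_graph_partition, tp_size=4):
    # Branch 1: full-graph compilation (no splitting ops) or inductor graph
    # partition -- num_tokens is padded to a multiple of tp_size, so any
    # shape, including the symbolic `None`, qualifies.
    if not splitting_ops or use_inductor_graph_partition:
        return True
    # Branch 2: piecewise compilation -- only concrete shapes divisible by
    # tp_size (e.g., entries in `compile_sizes`) qualify.
    return shape is not None and shape % tp_size == 0

assert is_applicable(None, splitting_ops=[], use_inductor_graph_partition=False)
assert not is_applicable(None, splitting_ops=["attn"], use_inductor_graph_partition=False)
assert is_applicable(8192, splitting_ops=["attn"], use_inductor_graph_partition=False)
assert not is_applicable(8190, splitting_ops=["attn"], use_inductor_graph_partition=False)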

2 changes: 2 additions & 0 deletions vllm/compilation/vllm_inductor_pass.py
@@ -3,6 +3,7 @@
 import functools
 import operator
 import time
+import weakref
 from typing import ClassVar
 
 import regex as re
@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
     """Keep track of pass index for debug dump ordering."""
 
     def __init__(self, config: VllmConfig):
+        self.compilation_config = weakref.proxy(config.compilation_config)
         self.pass_config = config.compilation_config.pass_config
         self.model_dtype = config.model_config.dtype if config.model_config else None
         self.device = config.device_config.device if config.device_config else None
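The weakref.proxy here gives each pass a non-owning reference to the compilation config, so long-lived pass objects neither keep the config alive nor create a strong reference cycle through it. A minimal sketch of the proxy semantics, using a stand-in config class rather than vLLM's real one:

import weakref

class CompilationConfig:  # stand-in, not vLLM's real class
    def __init__(self):
        self.splitting_ops = ["vllm::unified_attention"]  # illustrative value

cfg = CompilationConfig()
proxy = weakref.proxy(cfg)
print(proxy.splitting_ops)  # attribute access is forwarded to the referent

del cfg  # drop the only strong reference
try:
    proxy.splitting_ops
except ReferenceError:
    print("proxy did not keep the config alive")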