Changes from 2 commits
vllm/compilation/collective_fusion.py (4 additions, 2 deletions)
@@ -131,8 +131,10 @@ def __init__(self, config: VllmConfig):
         AllGatherGEMMPattern(self.model_dtype,
                              self.device).register(self.patterns)
 
-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
-        # only do replace for specific shapes
+    def is_applicable(self, splitting_ops: list[str],
+                      shape: Optional[int]) -> bool:
+        if splitting_ops is None or splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

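To make the new gate concrete, here is a standalone restatement of its logic; `tp_size` is passed explicitly rather than read from the TP group, and the op name in the example calls is purely illustrative:

```python
from typing import Optional


def is_applicable(splitting_ops: Optional[list[str]],
                  shape: Optional[int], tp_size: int) -> bool:
    # No splitting ops configured (full-graph compilation):
    # the pass is applicable regardless of shape.
    if splitting_ops is None or splitting_ops == []:
        return True
    # Piecewise compilation: keep the old gate and apply the pass
    # only for concrete shapes divisible by the TP world size.
    return shape is not None and shape % tp_size == 0


assert is_applicable(None, None, tp_size=4)              # no splitting ops
assert not is_applicable(["some.op"], None, tp_size=4)   # symbolic shape
assert not is_applicable(["some.op"], 6, tp_size=4)      # 6 % 4 != 0
assert is_applicable(["some.op"], 8, tp_size=4)          # 8 % 4 == 0
```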
vllm/compilation/inductor_pass.py (2 additions, 1 deletion)
@@ -92,7 +92,8 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: Optional[int]):
+    def is_applicable(self, splitting_ops: list[str],
+                      runtime_shape: Optional[int]):
         return True


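This default on the base pass class means that any pass which does not override `is_applicable` keeps running for every configured `splitting_ops` value and every runtime shape, exactly as it did before the rename.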
vllm/compilation/pass_manager.py (2 additions, 1 deletion)
@@ -40,14 +40,15 @@ def __init__(self):
     def __call__(self, graph: fx.Graph):
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(self.splitting_ops, shape):
                 pass_(graph)
 
         # always run fix_functionalization last
         self.fix_functionalization(graph)
 
     def configure(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
+        self.splitting_ops = config.compilation_config.splitting_ops
         if self.pass_config.enable_noop:
             self.passes += [NoOpEliminationPass(config)]

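As a sketch of the resulting call-site flow: the manager is configured once with the compilation config's `splitting_ops` and then consults each pass per runtime shape. The stand-in classes below are simplified illustrations, not vLLM's real pass implementations:

```python
from typing import Any, Optional


class AlwaysOnPass:
    """Mirrors the base-class default: applicable everywhere."""

    def is_applicable(self, splitting_ops: list[str],
                      runtime_shape: Optional[int]) -> bool:
        return True

    def __call__(self, graph: Any) -> None:
        print("AlwaysOnPass ran")


class TpGatedPass:
    """Stand-in for the TP-gated passes above (tp_size injected)."""

    def __init__(self, tp_size: int) -> None:
        self.tp_size = tp_size

    def is_applicable(self, splitting_ops: list[str],
                      runtime_shape: Optional[int]) -> bool:
        if splitting_ops is None or splitting_ops == []:
            return True
        return runtime_shape is not None and runtime_shape % self.tp_size == 0

    def __call__(self, graph: Any) -> None:
        print("TpGatedPass ran")


def run_passes(passes: list, splitting_ops: list[str],
               runtime_shape: Optional[int], graph: Any = None) -> None:
    # Mirrors the __call__ loop in pass_manager.py above: the manager now
    # forwards the configured splitting ops alongside the runtime shape.
    for pass_ in passes:
        if pass_.is_applicable(splitting_ops, runtime_shape):
            pass_(graph)


passes = [AlwaysOnPass(), TpGatedPass(tp_size=4)]
run_passes(passes, splitting_ops=[], runtime_shape=None)        # both run
run_passes(passes, splitting_ops=["some.op"], runtime_shape=6)  # AlwaysOnPass only
```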
vllm/compilation/sequence_parallelism.py (4 additions, 1 deletion)
@@ -469,7 +469,10 @@ def __init__(self, config: VllmConfig):
         # and allow multiple values of epsilon.
         torch._inductor.pattern_matcher._seen_patterns.clear()
 
-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
+    def is_applicable(self, splitting_ops: list[str],
+                      shape: Optional[int]) -> bool:
+        if splitting_ops is None or splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

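Net effect across the four files: `is_applicable_for_shape(shape)` becomes `is_applicable(splitting_ops, shape)`, the pass manager captures `compilation_config.splitting_ops` in `configure` and forwards it on every invocation, and the TP-gated passes in collective_fusion.py and sequence_parallelism.py short-circuit to applicable when no splitting ops are configured, keeping the shape-divisibility gate only for piecewise compilation.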