Commit b158df2

remove resolve_op_overloads and use splitting_ops directly (#28081)
Signed-off-by: Boyuan Feng <[email protected]>
Parent: 1aaecda

File tree: 3 files changed, +89 -65 lines
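In short, the eager name-resolution step is gone: split_graph now takes the raw splitting_ops strings, and each FX node is matched by name in the new should_split helper. A minimal sketch of the resulting behavior, assuming a vLLM checkout at this commit (the torch.relu trace and the op names are illustrative, not from the diff):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.partition_rules import should_split

# Trace a toy function down to aten ops so node targets are OpOverloads.
gm = make_fx(torch.relu)(torch.randn(2))

for node in gm.graph.nodes:
    # Unknown names no longer need (and can no longer fail) an up-front
    # resolution step; they simply never match any node.
    assert not should_split(node, ["aten::nonexistent_op"])
    if should_split(node, ["aten::relu"]):
        print("would split at", node.target)  # aten.relu.default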

tests/compile/test_config.py

Lines changed: 63 additions & 19 deletions
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
-
-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
+    from vllm.compilation.partition_rules import should_split
+
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
+    )
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
     )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)


 @pytest.mark.skipif(
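One subtlety in the aten::add.Tensor case above: the node targets the default overload, and should_split compares overload names exactly, so a string naming a different overload never matches. A standalone illustration, under the same assumptions as the sketch above:

import torch
from vllm.compilation.partition_rules import should_split

graph = torch.fx.Graph()
node = graph.call_function(torch.ops.aten.add.default)

assert should_split(node, ["aten::add.default"])     # exact overload matches
assert should_split(node, ["aten::add"])             # bare packet name matches
assert not should_split(node, ["aten::add.Tensor"])  # other overloads do not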

vllm/compilation/backends.py

Lines changed: 5 additions & 10 deletions
@@ -19,7 +19,7 @@
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+
+        if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -653,8 +649,7 @@ def __call__(
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []

-        resolved_split_ops = resolve_defined_ops(fx_split_ops)
-        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code

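For context, the strings handed to split_graph originate in the compilation config; a hedged sketch (the splitting_ops field is referenced in this diff, the op value is illustrative):

from vllm.config import CompilationConfig

# splitting_ops is already a plain list of strings.
config = CompilationConfig(splitting_ops=["vllm::unified_attention"])

# backends.py (above) now forwards it with no resolution step in between:
#     fx_split_ops = self.compilation_config.splitting_ops or []
#     self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)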
vllm/compilation/partition_rules.py

Lines changed: 21 additions & 36 deletions
@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import logging

 import torch
-from torch._library.utils import lookup_op

 from vllm.logger import init_logger

 logger = init_logger(__name__)


-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
+    """
+    Check if a node should be split for dynamo graph partition.
+    It operates on dynamo graph, so the node.target can be anything.
+    We need to check and split only on OpOverload and OpOverloadPacket.
+    """

-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
+    if node.op != "call_function":
+        return False

-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
+    target = node.target

-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops

-    Returns:
-        List of successfully resolved operator overloads
-    """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
-
-    return resolved
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+
+    return False


 @contextlib.contextmanager
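The three name forms should_split compares can be inspected directly. A small sketch (the _qualified_op_name and _overloadname attributes are PyTorch internals relied on by the diff above, so they may change between releases; aten::relu is an illustrative choice):

import torch

packet = torch.ops.aten.relu            # OpOverloadPacket
overload = torch.ops.aten.relu.default  # OpOverload

print(packet._qualified_op_name)        # "aten::relu"
print(overload.name())                  # "aten::relu"
print(overload._overloadname)           # "default"

# Hence "aten::relu" matches both forms, while "aten::relu.default"
# matches only that specific overload.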
