13 changes: 9 additions & 4 deletions tests/compile/test_full_graph.py
@@ -16,7 +16,12 @@ def test_full_graph(model):
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="meta-llama/Meta-Llama-3-8B",
enforce_eager=True,
load_format="dummy")
llm.generate(prompts, sampling_params)
llm = LLM(model=model, enforce_eager=True)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
156 changes: 156 additions & 0 deletions vllm/compilation/backends.py
@@ -0,0 +1,156 @@
import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
"""
Rewrite the graph module to replace the pattern involving
torch._higher_order_ops.auto_functionalize.auto_functionalized
with a direct call to the inplace custom op.

# TODO: check if PyTorch nightly has fixed this issue
"""

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs['query']
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(torch.ops._C.rotary_embedding.default,
                                        kwargs=kwargs)

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (getitem_user.op == 'call_function'
                                    and getitem_user.target
                                    == torch.ops.aten.slice_scatter.default):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs['input']
                residual = kwargs['residual']

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        # Remove the getitem node
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']
                weight = kwargs['weight']
                epsilon = kwargs['epsilon']
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def vllm_backend(graph, example_inputs):
    from torch._inductor import config
    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx
    current_config['post_grad_custom_post_pass'] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
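For orientation, a minimal usage sketch (not part of this diff): torch.compile accepts any callable backend, so the function above can be passed directly. TinyModel and its input are placeholders; the custom-op rewrites only apply when vLLM's torch.ops._C kernels appear in the captured graph.

import torch
import torch.nn as nn

from vllm.compilation.backends import vllm_backend


class TinyModel(nn.Module):
    # Placeholder module; real usage compiles a full vLLM model whose graph
    # contains the _C custom ops that fix_functionalization rewrites.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


model = TinyModel()
# Inductor compiles the graph and runs fix_functionalization as its
# post-grad custom post pass before code generation.
compiled = torch.compile(model, fullgraph=True, backend=vllm_backend)
print(compiled(torch.randn(4, 8)))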
3 changes: 2 additions & 1 deletion vllm/worker/model_runner.py
@@ -1064,8 +1064,9 @@ def load_model(self) -> None:
"This may lead to less accurate results!")

if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
from vllm.compilation.backends import vllm_backend
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
backend = get_torch_compile_backend() or vllm_backend
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
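End to end, a hedged sketch of exercising this code path, assuming the VLLM_TEST_DYNAMO_* flags are enabled by setting them to "1" (mirroring tests/compile/test_full_graph.py above); the model name is only an example.

import os

# Assumption: these test-only flags are read as booleans from the environment.
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams

# load_model() now wraps the model with torch.compile using vllm_backend,
# unless a plugin has registered a different backend.
llm = LLM(model="meta-llama/Meta-Llama-3-8B", enforce_eager=True)
outputs = llm.generate(["The future of AI is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)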