-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass #10902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
5dc6d69
fe74515
cd15aca
427bb9d
acb8557
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from typing import Union | ||
| from typing import Iterable, Union | ||
|
|
||
| import torch.fx | ||
| from torch import SymInt | ||
|
|
@@ -15,13 +15,13 @@ | |
|
|
||
| class RedundantReshapesPass(VllmInductorPass): | ||
| """ | ||
| This is an inductor pass that removes redundant reshape operations. | ||
| This is an inductor pass that removes redundant reshape/slice operations. | ||
| It is required for RMSNorm-quant fusion to work properly. | ||
| That's because apply_fp8_linear adds a reshape, which is redundant | ||
| in the 2D-case. | ||
|
|
||
| Example graph: | ||
| in the 2D-case. Additionally, torch internal no-op elimination pass does | ||
| not handle certain slice variants. | ||
|
|
||
| Example graph 1: | ||
| getitem_1: "f16[s0, 4096]" = ... | ||
| view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) | ||
| at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) | ||
|
|
@@ -31,6 +31,22 @@ class RedundantReshapesPass(VllmInductorPass): | |
| getitem_1: "f16[s0, 4096]" = ... | ||
| at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) | ||
| out: "f8e4m3fn[s0, 4096]" = at[1] | ||
|
|
||
| Example graph 2: | ||
| arg0: "s0" = ... | ||
| scaled_mm: "f16[s0, 4096]" = ... | ||
| slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) | ||
| at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) | ||
| out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) | ||
|
|
||
| Can be replaced with: | ||
| arg0: "s0" = ... | ||
| scaled_mm: "f16[s0, 4096]" = ... | ||
| at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) | ||
| out: "f16[s0, 4096]" = at[1] | ||
|
|
||
| TODO(luka): This is currently tested in test_fusion, | ||
| but separate tests could be good. | ||
| """ | ||
|
|
||
| def __call__(self, graph: torch.fx.Graph): | ||
|
|
@@ -50,18 +66,48 @@ def __call__(self, graph: torch.fx.Graph): | |
| # Invalid reshape args, skip | ||
| continue | ||
|
|
||
| if all( | ||
| self.dims_equivalent(s, i_s) | ||
| for s, i_s in zip(shape, input_shape)): | ||
| if self.all_dims_equivalent(shape, input_shape): | ||
| node.replace_all_uses_with(input) | ||
| graph.erase_node(node) | ||
| count += 1 | ||
|
|
||
| elif is_func(node, torch.ops.aten.slice.Tensor): | ||
| input, dim_index, start, end = node.args[:4] | ||
| input_shape = input.meta["val"].shape | ||
| i_dim = input_shape[dim_index] | ||
|
|
||
| if start == 0 and self.dims_equivalent(end, i_dim): | ||
| node.replace_all_uses_with(input) | ||
| graph.erase_node(node) | ||
| count += 1 | ||
|
|
||
| elif is_func(node, torch.ops.aten.slice_scatter.default): | ||
|
||
| base, view, dim_index, start, end = node.args[:5] | ||
| base_shape = base.meta["val"].shape | ||
| view_shape = view.meta["val"].shape | ||
|
|
||
| view_dim = view_shape[dim_index] | ||
|
|
||
| # Check that view fully covers base and the full view is used | ||
| # (if the view fully covered the base after slicing but was not | ||
| # fully used, we could replace slice_scatter with a simple slice | ||
| # but that's a niche case). | ||
| if (base_shape == view_shape and start == 0 | ||
| and self.dims_equivalent(end, view_dim)): | ||
| node.replace_all_uses_with(view) | ||
| graph.erase_node(node) | ||
| count += 1 | ||
|
|
||
| logger.debug("Removed %s no-op reshapes", count) | ||
|
|
||
| self.dump_graph(graph, "after_reshapes") | ||
| self.end_and_log() | ||
|
|
||
| def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| i_dims: Iterable[Union[int, SymInt]]): | ||
| return all( | ||
| self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) | ||
|
|
||
| def dims_equivalent(self, dim: Union[int, torch.fx.Node], | ||
| i_dim: Union[int, SymInt]) -> bool: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.config import CompilationLevel, get_current_vllm_config | ||
| from vllm.platforms import current_platform | ||
|
|
||
| # Input scaling factors are no longer optional in _scaled_mm starting | ||
|
|
@@ -161,10 +162,14 @@ def apply_fp8_linear( | |
| # Note: we pad the input because torch._scaled_mm is more performant | ||
| # for matrices with batch dimension > 16. | ||
| # This could change in the future. | ||
| # We also don't pad when using torch.compile, | ||
| # as it breaks with dynamic shapes. | ||
| config = get_current_vllm_config().compilation_config | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this cached? It could be expensive on each forward call.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, in eager mode this will get called on every forward pass, but it will only happen once when compiled. In eager mode there isn't really a better way that's still correct - the only way is to check the config context. I don't think this getter is significant but I haven't measured it.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could pass in a flag instead.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we'd have to pass that flag through the whole call stack, though, so I don't think it's worth it. I'll run a small model.
||
| do_pad = config.level < CompilationLevel.PIECEWISE | ||
| qinput, x_scale = ops.scaled_fp8_quant( | ||
| input_2d, | ||
| input_scale, | ||
| num_token_padding=17, | ||
| num_token_padding=17 if do_pad else None, | ||
| use_per_token_if_dynamic=use_per_token_if_dynamic) | ||
|
|
||
| per_tensor_weights = (weight_scale.numel() == 1) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.