[ Kernel ] AWQ Fused MoE #6422
@@ -0,0 +1,74 @@
| """Fused MoE utilities for AWQ.""" | ||
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.logger import init_logger | ||
|
|
||
| from .fused_moe import fused_experts, moe_align_block_size | ||

logger = init_logger(__name__)

NAIVE_THRESHOLD = 1024

Member: This seems a bit high, and it is worth commenting how it was calibrated (what model, benchmark, and GPU were used).

Contributor: @robertgshaw2-neuralmagic do we know why this is 1024 specifically?
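
For context on what the threshold gates, here is a minimal sketch, assuming `hidden_states` arrives flattened to `(num_tokens, hidden_size)` as the `view(...)` further down suggests; the shapes are illustrative, not from the PR:

```python
import torch

NAIVE_THRESHOLD = 1024

# hidden_states reaches the MoE layer flattened to (num_tokens, hidden_size),
# so shape[:-1].numel() is simply the number of tokens in the current batch.
hidden_states = torch.randn(2048, 4096)           # e.g. a large prefill chunk
num_tokens = hidden_states.shape[:-1].numel()     # -> 2048

# At or above the threshold, the AWQ weights are dequantized and the fp16
# fused_experts kernel is used; below it, the quantized awq_fused_moe path runs.
use_dequant_path = num_tokens >= NAIVE_THRESHOLD  # -> True
print(num_tokens, use_dequant_path)
```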

def fused_experts_awq(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scales: torch.Tensor,
    w2_scales: torch.Tensor,
    w1_qzeros: torch.Tensor,
    w2_qzeros: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    pack_factor: int,
) -> torch.Tensor:
    """
    This function computes a fused MoE (fused_experts) forward pass for
    AWQ-quantized expert weights.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scales (torch.Tensor): Scales to be used for w1.
    - w2_scales (torch.Tensor): Scales to be used for w2.
    - w1_qzeros (torch.Tensor): Zero points to be used for w1.
    - w2_qzeros (torch.Tensor): Zero points to be used for w2.
    - topk_weights (torch.Tensor): Routing weights of the top-k selected experts.
    - topk_ids (torch.Tensor): Indices of the top-k selected experts.
    - pack_factor (int): Weight packing factor (int4 in int32 == 8).

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # If large seq_len prefill, dequantize and use the fp16 MoE kernel.
    do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD
    if do_naive_dequant:
        # TODO: why is this not contiguous already?
        # from @dsikka: because of the permutation operation
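        # NOTE (assumed layout): awq_dequantize appears to return weights as
        # (num_experts, K, N) while fused_experts expects (num_experts, N, K),
        # so the permute below swaps the last two dims, which is also what
        # breaks contiguity.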
        dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
                                        0).permute(0, 2, 1)
        dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
                                        0).permute(0, 2, 1)

        return fused_experts(hidden_states, dequant_w1, dequant_w2,
                             topk_weights, topk_ids)
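
    # moe_align_block_size sorts the token -> expert assignments and pads them
    # into blocks of 16 tokens so that each block is handled by a single expert
    # inside the AWQ MoE kernel (w1.shape[0] is the number of experts).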
    (sorted_token_ids, expert_ids,
     num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0])

    x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:])

    gate_up = ops.awq_fused_moe(x, w1, w1_scales, w1_qzeros, topk_weights,
                                sorted_token_ids, expert_ids,
                                num_tokens_post_padded, False, pack_factor)
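
    # NOTE: ops.silu_and_mul(out, x) writes silu(x[..., :d]) * x[..., d:] into
    # out, with d = x.shape[-1] // 2 (a fused SwiGLU activation), which is why
    # the buffer below has half the last dimension of gate_up.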
    out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)
    ops.silu_and_mul(out, gate_up)

    out = ops.awq_fused_moe(out, w2, w2_scales, w2_qzeros, topk_weights,
                            sorted_token_ids, expert_ids,
                            num_tokens_post_padded, True, pack_factor)

    return torch.sum(out, dim=1)
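
As a side note on the `pack_factor` mentioned in the docstring (int4 in int32 == 8), here is a small sketch of the packing arithmetic. The real AWQ layout may interleave nibbles in a different order, so this only illustrates why packed dimensions are divided by `pack_factor`:

```python
bits = 4
pack_factor = 32 // bits                # 8 int4 values per int32 word
values = [1, 7, 0, 15, 3, 9, 12, 5]     # eight 4-bit weights
packed = 0
for i, v in enumerate(values):
    packed |= (v & 0xF) << (bits * i)   # nibble i occupies bits [4*i, 4*i + 4)

unpacked = [(packed >> (bits * i)) & 0xF for i in range(pack_factor)]
assert unpacked == values

# A (K, N) int4 weight matrix is therefore stored as an int32 tensor of shape
# (K, N // pack_factor), which is why qweight/qzeros shapes carry this factor.
```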
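For reviewers who want the unfused math in one place, below is a plain PyTorch reference of what this function computes, under the assumed unquantized layouts `hidden_states: (T, K)`, `w1: (E, 2N, K)`, `w2: (E, K, N)`. This is a sketch for checking semantics, not the PR's implementation:

```python
import torch
import torch.nn.functional as F


def moe_reference(hidden_states, w1, w2, topk_weights, topk_ids):
    """Per-token loop: route each token to its top-k experts, apply the
    SwiGLU expert MLP, and combine the results with the routing weights."""
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        x = hidden_states[t]
        for weight, expert in zip(topk_weights[t], topk_ids[t]):
            gate_up = w1[expert] @ x                    # (2N,)
            n = gate_up.shape[0] // 2
            act = F.silu(gate_up[:n]) * gate_up[n:]     # silu_and_mul
            out[t] += weight * (w2[expert] @ act)       # (K,)
    return out
```

The quantized path in the PR performs the same computation: the two `awq_fused_moe` calls are the gate/up and down projections, `silu_and_mul` is the activation, and the final `torch.sum(out, dim=1)` is the combine over the top-k experts.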