[Kernel] Support W8A8 channel-wise weights and per-token activations in triton fused_moe_kernel #16366
Merged: robertgshaw2-redhat merged 5 commits into vllm-project:main from neuralmagic:triton-fused-moe-support-w8a8-ptpc on Apr 11, 2025.
This view shows changes from 3 of the 5 commits.
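For context, the capability these tests exercise is that `fused_moe` can now combine channel-wise INT8 weight scales with per-token activation quantization. Below is a minimal sketch of the call pattern; the shapes, flags, and random tensors simply mirror the per-channel test further down and are illustrative, not part of this PR's diff (requires a CUDA device and vLLM installed):

```python
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

torch.set_default_device("cuda")
M, N, K, E, topk = 4, 128, 256, 8, 2  # illustrative sizes

a = torch.randn(M, K, dtype=torch.bfloat16) / 10            # activations
w1 = torch.randint(-128, 128, (E, 2 * N, K), dtype=torch.int8)
w2 = torch.randint(-128, 128, (E, K, N), dtype=torch.int8)
w1_s = torch.rand(E, 2 * N, dtype=torch.float32) * 1e-2     # one scale per output channel
w2_s = torch.rand(E, K, dtype=torch.float32) * 1e-2
score = torch.randn(M, E, dtype=torch.bfloat16)             # router logits

out = fused_moe(
    a,
    w1,
    w2,
    score,
    topk,
    renormalize=False,
    use_int8_w8a8=True,      # INT8 weights and activations
    per_channel_quant=True,  # channel-wise weight scales, per-token activation scales
    w1_scale=w1_s,
    w2_scale=w2_s,
    block_shape=None,        # no block quantization
)
print(out.shape)             # torch.Size([M, K])
```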
New test file (221 lines added): a torch-native reference implementation and test for block-wise W8A8 INT8 quantization in the Triton fused MoE kernel.

```python
# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA compute capability 7.0 or higher",
                allow_module_level=True)


# For test
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s


# For test
def native_w8a8_block_int8_matmul(A,
                                  B,
                                  As,
                                  Bs,
                                  block_size,
                                  output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise
    quantization using native torch.

    It takes two input tensors `A` and `B` (int8) with scales `As` and
    `Bs` (float32).
    The output is returned in the specified `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [
        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
    ]
    B_tiles = [[
        B[
            j * block_n:min((j + 1) * block_n, N),
            i * block_k:min((i + 1) * block_k, K),
        ] for i in range(k_tiles)
    ] for j in range(n_tiles)]
    C_tiles = [
        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
    ]
    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_int8_matmul(a_q[mask],
                                                      w1[i],
                                                      a_s[mask],
                                                      w1_s[i],
                                                      block_shape,
                                                      output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_int8_matmul(act_out_q,
                                                      w2[i],
                                                      act_out_s,
                                                      w2_s[i],
                                                      block_shape,
                                                      output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
N = [128, 1024]
K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    torch.set_default_device("cuda")


@pytest.mark.parametrize(
    "M, N, K, E, topk, block_size, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference."""
    torch.manual_seed(seed)
    # Use a smaller factor for scale initialization to prevent large
    # values/overflow especially when output dtype might be float16
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_fp32 = (torch.rand(
        (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
    w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
    w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1_s = (torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
    w2_s = (torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)

    score = torch.randn((M, E), dtype=dtype)

    out = fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_int8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )
    ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                        block_size)

    # Check results
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.06
```
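For readers unfamiliar with the block-wise scheme above: each token's hidden vector is quantized in groups of `block_k` values with one float32 scale per group, and each weight matrix carries one scale per `[block_n, block_k]` tile. A small standalone sanity check of the per-token-group reference helper (assuming `native_per_token_group_quant_int8` from the file above is in scope; not part of this PR) might look like:

```python
import torch

x = torch.randn(4, 256)  # 4 tokens, hidden size 256
x_q, x_s = native_per_token_group_quant_int8(x, group_size=128)
print(x_q.shape, x_q.dtype)  # torch.Size([4, 256]) torch.int8
print(x_s.shape)             # torch.Size([4, 2]): one scale per 128-wide group

# Dequantize and compare; the error is bounded by half a quantization
# step (x_s) per group.
x_dq = (x_q.to(torch.float32).reshape(-1, 128) *
        x_s.reshape(-1, 1)).reshape(4, 256)
print((x_dq - x).abs().max())
```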
New test file (149 lines added): a torch-native reference implementation and test for channel-wise (per-column) INT8 weights with per-token activation quantization in fused_moe.

```python
# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA compute capability 7.0 or higher",
                allow_module_level=True)


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input
    quantization and per-column weight quantization"""
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(
    ), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (K, )
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
    C = torch.matmul(A, B)  # [M, K]
    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale

    return C.reshape(origin_C_shape).to(output_dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization
    using native torch."""

    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(a_q[mask],
                                                     w1[i],
                                                     a_s[mask],
                                                     w1_s[i],
                                                     output_dtype=a.dtype)
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(act_out_q,
                                                     w2[i],
                                                     act_out_s,
                                                     w2_s[i],
                                                     output_dtype=a.dtype)
    # Apply routing weights and sum
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    torch.set_default_device("cuda")


DTYPES = [torch.half, torch.bfloat16]
M = [1, 33]
N = [128, 1024]
K = [256, 4096]
E = [8]
TOP_KS = [2, 6]
SEEDS = [0]


@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
                         itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_int8_fused_moe(M, N, K, E, topk, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 per-channel weight scales and
    per-token activation quantization against a native torch reference."""
    torch.manual_seed(seed)
    # Initialize int8 quantization parameters
    factor_for_scale = 1e-2
    int8_max = 127
    int8_min = -128

    # Input tensor of shape [M, K]
    a = torch.randn((M, K), dtype=dtype) / 10

    # Generate int8 weights
    w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
    w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
    w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    # Generate scale for each column (per-column quantization)
    w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
    w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
    score = torch.randn((M, E), dtype=dtype)

    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
    out = fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_int8_w8a8=True,  # Using int8-w8a8
        per_channel_quant=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=None,  # Not using block quantization
    )

    # Check results
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.05
```
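The per-channel path tested above pairs one scale per token (activations) with one scale per output channel (weights) and applies both scales after the int8 matmul. A self-contained, pure-torch sketch of that rescaling (illustrative names and shapes only, not code from this PR):

```python
import torch

M, K, N = 8, 64, 32
a = torch.randn(M, K)   # activations
w = torch.randn(N, K)   # weight, one row per output channel

# Per-token activation quantization: one scale per row of `a`.
a_s = a.abs().amax(dim=-1, keepdim=True) / 127.0          # [M, 1]
a_q = (a / a_s).round().clamp(-128, 127).to(torch.int8)

# Per-channel weight quantization: one scale per output channel (row of `w`).
w_s = w.abs().amax(dim=-1, keepdim=True) / 127.0          # [N, 1]
w_q = (w / w_s).round().clamp(-128, 127).to(torch.int8)

# Int8 matmul with both scales applied afterwards, mirroring
# native_w8a8_per_token_matmul above.
out = (a_q.float() @ w_q.float().t()) * a_s * w_s.t()     # [M, N]
ref = a @ w.t()
print((out - ref).abs().max() / ref.abs().max())          # small relative error
```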