From fd49cc6dc618c63de5a00e44c736f5af6d5957ab Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 9 Apr 2025 23:13:47 +0000 Subject: [PATCH 1/4] Support W8A8 channelwise in triton moe Signed-off-by: mgoin --- tests/kernels/test_block_int8.py | 221 +++++++++ tests/kernels/test_int8_kernel.py | 148 ++++++ tests/kernels/test_triton_moe_ptpc_fp8.py | 158 ++++++ .../layers/fused_moe/fused_moe.py | 270 +++++++---- .../layers/quantization/utils/int8_utils.py | 459 ++++++++++++++++++ 5 files changed, 1169 insertions(+), 87 deletions(-) create mode 100644 tests/kernels/test_block_int8.py create mode 100644 tests/kernels/test_int8_kernel.py create mode 100644 tests/kernels/test_triton_moe_ptpc_fp8.py create mode 100644 vllm/model_executor/layers/quantization/utils/int8_utils.py diff --git a/tests/kernels/test_block_int8.py b/tests/kernels/test_block_int8.py new file mode 100644 index 000000000000..9161e63675e4 --- /dev/null +++ b/tests/kernels/test_block_int8.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py +import itertools + +import pytest +import torch + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (7, 0): + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", + allow_module_level=True) + + +# For test +def native_per_token_group_quant_int8(x, + group_size, + eps=1e-10, + dtype=torch.int8): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch. + + It converts the tensor values into int8 values and returns the + quantized tensor along with the scaling factor used for quantization. + """ + assert (x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_min = iinfo.min + int8_max = iinfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + # Use float32 for scale calculation for stability + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / int8_max + x_q = (x_.to(torch.float32) / x_s).round().clamp( + min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +# For test +def native_w8a8_block_int8_matmul(A, + B, + As, + Bs, + block_size, + output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise + quantization using native torch. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. 
+ """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +# For test +def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe with block-wise quantization using + native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_int8(a, block_k) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_int8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_int8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_int8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +DTYPES = [torch.half, torch.bfloat16] +M = [1, 33, 64, 222] +N = [128, 1024] +K = [256, 4096] +E = [8, 24] +TOP_KS = [2, 6] +# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] +BLOCK_SIZE = [[128, 128]] +SEEDS = [0] + + +@pytest.fixture(autouse=True, scope="module") +def setup_cuda(): + """Sets the default CUDA device for all tests in this module.""" + torch.set_default_device("cuda") + + +@pytest.mark.parametrize( + "M, N, K, E, topk, block_size, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + """Tests the fused_moe kernel with W8A8 INT8 block quantization against a + native torch reference.""" + torch.manual_seed(seed) + # Use a smaller factor for scale initialization to prevent large + # values/overflow especially when output dtype might be float16 + factor_for_scale = 1e-2 + int8_info = torch.iinfo(torch.int8) + 
int8_max, int8_min = int8_info.max, int8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand( + (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max + w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max + w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = (torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale) + w2_s = (torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale) + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + # Check results + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.06 diff --git a/tests/kernels/test_int8_kernel.py b/tests/kernels/test_int8_kernel.py new file mode 100644 index 000000000000..15420211a0ae --- /dev/null +++ b/tests/kernels/test_int8_kernel.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools + +import pytest +import torch + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_quant_int8) +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (7, 0): + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", + allow_module_level=True) + + +def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): + """Matrix multiplication function that supports per-token input + quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous( + ), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K, ) + A = A.reshape(M, N) + + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + return C.reshape(origin_C_shape).to(output_dtype) + + +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): + """This function performs fused moe with per-column int8 quantization + using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + # Calculate routing + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + 
mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + output_dtype=a.dtype) + # Activation function + act_out = SiluAndMul().forward_native(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = per_token_quant_int8(act_out) + + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + output_dtype=a.dtype) + # Apply routing weights and sum + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.fixture(autouse=True, scope="module") +def setup_cuda(): + """Sets the default CUDA device for all tests in this module.""" + torch.set_default_device("cuda") + + +DTYPES = [torch.half, torch.bfloat16] +M = [1, 33] +N = [128, 1024] +K = [256, 4096] +E = [8] +TOP_KS = [2, 6] +SEEDS = [0] + + +@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): + torch.manual_seed(seed) + # Initialize int8 quantization parameters + factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / 10 + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + # Generate scale for each column (per-column quantization) + w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale + w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale + score = torch.randn((M, E), dtype=dtype) + + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, # Using int8-w8a8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) + + # Check results + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.05 diff --git a/tests/kernels/test_triton_moe_ptpc_fp8.py b/tests/kernels/test_triton_moe_ptpc_fp8.py new file mode 100644 index 000000000000..27dddf4ffa69 --- /dev/null +++ b/tests/kernels/test_triton_moe_ptpc_fp8.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) + + +def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): + """Matrix multiplication function that supports per-token input + quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous( + ), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose 
weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K, ) + A = A.reshape(M, N) + + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + return C.reshape(origin_C_shape).to(output_dtype) + + +def fp8_mask(a, mask): + dtype = a.dtype + return a.view(torch.int8)[mask].view(dtype) + + +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): + """This function performs fused moe with per-column int8 + quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + # Calculate routing + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + fp8_mask(a_q, mask), + w1[i], + fp8_mask(a_s, mask), + w1_s[i], + output_dtype=a.dtype, + ) + # Activation function + act_out = SiluAndMul().forward_native(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = ops.scaled_fp8_quant( + act_out, use_per_token_if_dynamic=True) + + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + output_dtype=a.dtype) + # Apply routing weights and sum + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.fixture(autouse=True, scope="module") +def setup_cuda(): + """Sets the default CUDA device for all tests in this module.""" + torch.set_default_device("cuda") + + +DTYPES = [torch.half, torch.bfloat16] +M = [1, 33] +N = [128, 1024] +K = [256, 4096] +E = [8] +TOP_KS = [2, 6] +SEEDS = [0] + + +@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): + torch.manual_seed(seed) + # Initialize int8 quantization parameters + factor_for_scale = 1e-2 + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = finfo.min + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / 10 + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, + max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, + max=fp8_max).to(torch.float8_e4m3fn) + + # Generate scale for each column (per-column quantization) + w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale + w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale + score = torch.randn((M, E), dtype=dtype) + + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) + 
+ # Check results + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.05 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a6a00040fb50..0378dcee4731 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -16,7 +16,10 @@ _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -251,50 +254,53 @@ def fused_moe_kernel_gptq_awq( @triton.jit def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr): + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, +): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. 
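# For illustration, a minimal plain-torch sketch of what the new per-channel
# W8A8 path in this kernel computes (names here are illustrative, not part of
# the patch; it mirrors native_w8a8_per_token_matmul in the new tests): the
# quantized matmul is accumulated in float32 and dequantized with a per-token
# activation scale (one value per row of A) and a per-output-channel weight
# scale (one value per row of B).
import torch


def w8a8_per_channel_ref(a_q, b_q, a_s, b_s, out_dtype=torch.bfloat16):
    # a_q: [M, K] quantized activations, a_s: [M, 1] float32 per-token scales
    # b_q: [N, K] quantized weights,     b_s: [N]    float32 per-channel scales
    acc = torch.matmul(a_q.to(torch.float32), b_q.to(torch.float32).t())
    return (acc * a_s.to(torch.float32) *
            b_s.to(torch.float32).view(1, -1)).to(out_dtype)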
@@ -371,12 +377,23 @@ def fused_moe_kernel( None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) - if use_fp8_w8a8: + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, + None] + # tensor-wise else: a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) @@ -400,7 +417,7 @@ def fused_moe_kernel( # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) - elif use_fp8_w8a8: + elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k @@ -412,7 +429,11 @@ def fused_moe_kernel( accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: - accumulator = tl.dot(a, b, acc=accumulator) + # fix out of shared memory issue + if use_fp8_w8a8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -426,7 +447,7 @@ def fused_moe_kernel( accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) - elif use_fp8_w8a8: + elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) else: @@ -457,27 +478,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, config: Dict[str, Any], compute_type: tl.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if use_fp8_w8a8: - assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) - - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - M = A.shape[0] num_tokens = M * top_k @@ -604,7 +613,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, top_k=top_k, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, BLOCK_SIZE_K=BLOCK_SIZE_K, **config, ) @@ -956,8 +967,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -969,9 +982,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, 
w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, global_num_experts, - expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + per_channel_quant, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def inplace_fused_experts_fake( @@ -983,8 +997,10 @@ def inplace_fused_experts_fake( activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1015,8 +1031,10 @@ def outplace_fused_experts( activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1028,7 +1046,8 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, - use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1042,8 +1061,10 @@ def outplace_fused_experts_fake( topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1092,8 +1113,10 @@ def fused_experts(hidden_states: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1132,8 +1155,10 @@ def fused_experts(hidden_states: torch.Tensor, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=w1_scale, @@ -1145,6 +1170,59 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) +def moe_kernel_prepare_input( + A: torch.Tensor, + B: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if use_fp8_w8a8: + assert B_scale is not None + if block_shape is None: + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. 
Otherwise, assume + # activation tensor-wise fp8 quantization, dynamic or static + A, A_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_channel_quant) + else: + # activation block-wise fp8 quantization + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a8: + assert B_scale is not None + if block_shape is None: + # activation channel-wise int8 quantization + assert (per_channel_quant + ), "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + # activation block-wise int8 quantization + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + return A, A_scale + + def fused_experts_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1154,8 +1232,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1257,23 +1337,26 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - a1q_scale: Optional[torch.Tensor] = None - - if use_fp8_w8a8: - qcurr_hidden_states, a1q_scale = _fp8_quantize( - curr_hidden_states, a1_scale, block_shape) - else: - qcurr_hidden_states = curr_hidden_states - a1q_scale = a1_scale + curr_hidden_states, a1_scale = moe_kernel_prepare_input( + A=curr_hidden_states, + B=w1, + A_scale=a1_scale, + B_scale=w1_scale, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(qcurr_hidden_states, + invoke_fused_moe_kernel(curr_hidden_states, w1, intermediate_cache1, - a1q_scale, + a1_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1285,8 +1368,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, block_shape=block_shape) if activation == "silu": @@ -1298,19 +1383,22 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - a2q_scale: Optional[torch.Tensor] = None - - if use_fp8_w8a8: - qintermediate_cache2, a2q_scale = _fp8_quantize( 
- intermediate_cache2, a2_scale, block_shape) - else: - qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + intermediate_cache2, a2_scale = moe_kernel_prepare_input( + A=intermediate_cache2, + B=w2, + A_scale=a2_scale, + B_scale=w2_scale, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape) - invoke_fused_moe_kernel(qintermediate_cache2, + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, - a2q_scale, + a2_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1322,8 +1410,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), @@ -1346,8 +1436,10 @@ def fused_moe( topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -1380,6 +1472,8 @@ def fused_moe( note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. + - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. 
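# An illustrative call pattern for the new flags (it mirrors the new
# tests/kernels/test_int8_kernel.py; the shapes and random weights below are
# assumptions made for this sketch, not values taken from the patch). Requires
# a CUDA device.
import torch

from vllm.model_executor.layers.fused_moe import fused_moe


def run_int8_channelwise_moe(M=32, N=128, K=256, E=8, topk=2):
    torch.set_default_device("cuda")
    a = torch.randn((M, K), dtype=torch.bfloat16) / 10  # activations
    w1 = torch.randint(-128, 128, (E, 2 * N, K), dtype=torch.int8)
    w2 = torch.randint(-128, 128, (E, K, N), dtype=torch.int8)
    # One float32 scale per output channel of each expert weight.
    w1_s = torch.rand(E, 2 * N, dtype=torch.float32) * 1e-2
    w2_s = torch.rand(E, K, dtype=torch.float32) * 1e-2
    score = torch.randn((M, E), dtype=torch.bfloat16)  # router logits
    return fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_int8_w8a8=True,  # int8 weights and activations
        per_channel_quant=True,  # channel-wise weight scales
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=None,  # None => channel-wise, not block-wise
    )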
@@ -1426,8 +1520,10 @@ def fused_moe( inplace=inplace, activation=activation, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=w1_scale, diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py new file mode 100644 index 000000000000..98b06b6c2ae9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py +import functools +import json +import logging +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +logger = logging.getLogger(__name__) + + +def apply_w8a8_block_int8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1]) + output = w8a8_block_int8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_int8( + x: torch.Tensor, + dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to int8 values with + tensor-wise quantization.""" + iinfo = torch.iinfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + int8_min, int8_max = iinfo.min, iinfo.max + scale = int8_max / amax + x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_dequant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> torch.Tensor: + """This function conducts block-wise dequantization. + The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are dequantized tensor. 
+ """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block[ + j * block_n:min((j + 1) * block_n, n), + i * block_k:min((i + 1) * block_k, k), + ] *= x_s[j][i] + + return x_dq_block + + +@triton.jit +def _per_token_quant_int8( + x_ptr, + xq_ptr, + scale_ptr, + stride_x, + stride_xq, + N, + BLOCK: tl.constexpr, +): + # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, + other=0.0).to(tl.float32) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x) + + +def per_token_quant_int8(x): + M = x.numel() // x.shape[-1] + N = x.shape[-1] + x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + scales = torch.empty(x.shape[:-1] + (1, ), + device=x.device, + dtype=torch.float32) + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + + assert x.is_contiguous() + _per_token_quant_int8[(M, )]( + x, + x_q, + scales, + stride_x=x.stride(-2), + stride_xq=x_q.stride(-2), + N=N, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + + return x_q, scales + + +@triton.jit +def _per_token_group_quant_int8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Columns of input + N, + # Avoid to divide zero + eps, + # Information for int8 + int8_min, + int8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + + This function converts the tensor values into int8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / int8_max + y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_int8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.int8, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + + It converts the tensor values into signed int8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.int8` + is supported for now. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. 
+ """ + assert (x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_max = iinfo.max + int8_min = iinfo.min + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size, ), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_int8[(M, )]( + x, + x_q, + x_s, + group_size, + N, + eps, + int8_min=int8_min, + int8_max=int8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_int8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, + None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, + block_k: int) -> 
Optional[Dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block fp8 kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501 + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block INT8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ("Using default W8A8 Block INT8 kernel config. Performance might " + "be sub-optimal! Config file not found at %s"), + config_file_path, + ) + return None + + +def w8a8_block_int8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be + 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. 
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + _w8a8_block_int8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C From b736f7d8889fa0bc3ab2685767c5ba5a1e2a1208 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 10 Apr 2025 00:44:27 +0000 Subject: [PATCH 2/4] Add comments Signed-off-by: mgoin --- tests/kernels/test_int8_kernel.py | 1 + tests/kernels/test_triton_moe_ptpc_fp8.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/kernels/test_int8_kernel.py b/tests/kernels/test_int8_kernel.py index 15420211a0ae..4c7543527c32 100644 --- a/tests/kernels/test_int8_kernel.py +++ b/tests/kernels/test_int8_kernel.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py import itertools import pytest diff --git a/tests/kernels/test_triton_moe_ptpc_fp8.py b/tests/kernels/test_triton_moe_ptpc_fp8.py index 27dddf4ffa69..44734e9340aa 100644 --- a/tests/kernels/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/test_triton_moe_ptpc_fp8.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_triton_moe_channel_fp8_kernel.py import itertools import pytest From 5bf67eeaf5d260a000020707d38fdacf3983a620 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 11 Apr 2025 00:13:30 +0000 Subject: [PATCH 3/4] Update unhelpful comment Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0378dcee4731..1d7356ad3a38 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -429,8 +429,8 @@ def fused_moe_kernel( accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: - # fix out of shared memory issue if use_fp8_w8a8: + # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) From 2c1eaac2cc30e0678bcf5613c299bf973af2ebec Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 11 Apr 2025 02:27:13 +0000 Subject: 
[PATCH 4/4] Review comments Signed-off-by: mgoin --- tests/kernels/test_block_fp8.py | 92 +++-------- tests/kernels/test_block_int8.py | 154 ++++++++---------- tests/kernels/utils_block.py | 63 +++++++ .../layers/fused_moe/fused_moe.py | 12 +- 4 files changed, 154 insertions(+), 167 deletions(-) create mode 100644 tests/kernels/utils_block.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 347319b303f4..c450048bf665 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -18,6 +18,8 @@ per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +from .utils_block import native_w8a8_block_matmul + dg_available = False try: import deep_gemm @@ -75,61 +77,6 @@ def native_per_token_group_quant_fp8(x, return x_q, x_s -def native_w8a8_block_fp8_matmul(A, - B, - As, - Bs, - block_size, - output_dtype=torch.float16): - """Matrix multiplication with block-wise quantization using native torch.""" - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape @@ -146,22 +93,22 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = native_per_token_group_quant_fp8( act_out, block_k) act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) + out[mask] = native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) @@ -215,8 +162,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, 
dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) rel_diff = (torch.mean( @@ -239,8 +186,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - vllm_config = VllmConfig() - a = torch.randn((M, K), dtype=dtype) / 10 w1_bf16 = (torch.rand( @@ -266,6 +211,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. + vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -334,8 +280,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): As = As_fp8.to(torch.float32) Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) diff --git a/tests/kernels/test_block_int8.py b/tests/kernels/test_block_int8.py index 9161e63675e4..9447f9d69165 100644 --- a/tests/kernels/test_block_int8.py +++ b/tests/kernels/test_block_int8.py @@ -6,10 +6,15 @@ import pytest import torch +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + w8a8_block_int8_matmul) from vllm.platforms import current_platform +from .utils_block import native_w8a8_block_matmul + if current_platform.get_device_capability() < (7, 0): pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) @@ -47,68 +52,6 @@ def native_per_token_group_quant_int8(x, return x_q, x_s -# For test -def native_w8a8_block_int8_matmul(A, - B, - As, - Bs, - block_size, - output_dtype=torch.float16): - """This function performs matrix multiplication with block-wise - quantization using native torch. - - It takes two input tensors `A` and `B` (int8) with scales `As` and - `Bs` (float32). - The output is returned in the specified `output_dtype`. 
- """ - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - # For test def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """This function performs fused moe with block-wise quantization using @@ -126,22 +69,22 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_int8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = native_per_token_group_quant_int8( act_out, block_k) act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_int8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) + out[mask] = native_w8a8_block_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) @@ -163,6 +106,38 @@ def setup_cuda(): torch.set_default_device("cuda") +@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max + A_fp8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max + B_fp8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = 
w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + @pytest.mark.parametrize( "M, N, K, E, topk, block_size, dtype, seed", itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @@ -199,20 +174,23 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + # Set the context to avoid lots of warning spam. + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/utils_block.py b/tests/kernels/utils_block.py new file mode 100644 index 000000000000..c16cba50967e --- /dev/null +++ b/tests/kernels/utils_block.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, + As: torch.Tensor, Bs: torch.Tensor, block_size, + output_dtype): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. 
+ """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1d7356ad3a38..38d739d55e55 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1337,7 +1337,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - curr_hidden_states, a1_scale = moe_kernel_prepare_input( + qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( A=curr_hidden_states, B=w1, A_scale=a1_scale, @@ -1353,10 +1353,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(curr_hidden_states, + invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - a1_scale, + qa1_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1383,7 +1383,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - intermediate_cache2, a2_scale = moe_kernel_prepare_input( + qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( A=intermediate_cache2, B=w2, A_scale=a2_scale, @@ -1395,10 +1395,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - invoke_fused_moe_kernel(intermediate_cache2, + invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - a2_scale, + qa2_scale, w2_scale, w2_zp, curr_topk_weights,