Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,17 @@ def _parse_quant_hf_config(self):
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
quant_cfg = modelopt_quant_config
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_cfg = modelopt_quant_config
quant_config_file = os.path.join(
self.model_path, "hf_quant_config.json"
)
with open(quant_config_file) as f:
quant_config_dict = json.load(f)
json_quant_configs = quant_config_dict["quantization"]
quant_algo = json_quant_configs.get("quant_algo", None)
if quant_algo == "MIXED_PRECISION":
quant_cfg = {"quant_method": "w4afp8"}
else:
quant_cfg = modelopt_quant_config
return quant_cfg

# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
Expand Down Expand Up @@ -389,6 +399,7 @@ def _verify_quantization(self) -> None:
"w8a8_fp8",
"moe_wna16",
"qoq",
"w4afp8",
]
compatible_quantization_methods = {
"modelopt_fp4": ["modelopt"],
Expand Down
215 changes: 215 additions & 0 deletions python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
"""Cutlass W4A8 MoE kernel."""
from typing import Optional

import torch
from sgl_kernel import (
cutlass_w4a8_moe_mm,
get_cutlass_w4a8_moe_mm_data,
sgl_per_tensor_quant_fp8,
silu_and_mul,
)

from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
)


def cutlass_w4a8_moe(
    start_expert_id: int,
    end_expert_id: int,
    total_num_experts: int,
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    local_topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Pipeline: quantize+scatter inputs per expert (Triton pre-reorder) ->
    grouped GEMM 1 -> SiLU-and-mul -> per-tensor fp8 re-quant ->
    grouped GEMM 2 -> gather+weighted-sum back to token order
    (Triton post-reorder).

    Parameters:
    - start_expert_id (int): First global expert id handled by this rank
      (passed through to the post-reorder kernel).
    - end_expert_id (int): Last global expert id handled by this rank
      (passed through to the post-reorder kernel).
    - total_num_experts (int): Total expert count; also used by the
      pre-reorder kernel as the "unassigned" sentinel id.
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids_ (torch.Tensor): Global expert ids chosen per token.
        Shape: [M, topk]
    - local_topk_ids (torch.Tensor): Expert ids remapped to this rank's
      local range; drives the permutation map and per-expert group sizes.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets / problem_sizes1 / problem_sizes2 (torch.Tensor):
      Workspace tensors filled in-place by get_cutlass_w4a8_moe_mm_data
      with per-expert row offsets and per-group GEMM problem shapes.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [1, K]
        NOTE(review): despite being Optional, a1_scale.float() is called
        unconditionally below — a None value would raise; confirm callers
        always pass it.
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [1, N]
        NOTE(review): same caveat as a1_scale — used unconditionally.
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.

    Returns:
    - torch.Tensor: Output tensor with the same shape and dtype as `a`
      (allocated via torch.empty_like(a)).
    """
    # --- Shape/dtype sanity checks (int4 weights are packed two per int8) ---
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert (
        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
        and w1_scale.shape[2] == w1_q.shape[1] * 4
    ), "W1 scale shape mismatch"
    assert (
        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
        and w2_scale.shape[2] == w2_q.shape[1] * 4
    ), "W2 scale shape mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)

    if apply_router_weight_on_input:
        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

    device = a.device

    # Build the src->dst permutation that groups token rows by local expert.
    _, src2dst, _ = run_cutlass_moe_ep_preproess(
        local_topk_ids,
        num_experts,
    )

    # fp8 activation buffer for the first grouped GEMM: one row per
    # (token, selected expert) pair, in expert-grouped order.
    gateup_input = torch.empty(
        (m * topk, k),
        device=device,
        dtype=torch.float8_e4m3fn,
    )

    # Quantize `a` with 1/a1_scale and scatter rows into expert-grouped order.
    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
        a,
        gateup_input,
        src2dst,
        local_topk_ids,
        a1_scale,
        total_num_experts,
        topk,
        k,
        BLOCK_SIZE=512,
    )

    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
    # they are kept to allow for a quick switch of the permutation logic
    # from the current triton kernel implementation to the cutlass-based one if needed.
    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    # Fill expert_offsets and the per-group problem shapes for both GEMMs.
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    # NOTE(review): c2 is zero-initialized while c1 uses empty() — presumably
    # rows for tokens without a local expert are never written by GEMM 2 but
    # may still be read by the post-reorder gather; confirm.
    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)

    # Grouped GEMM 1: gate+up projection. The literal 128 is the weight
    # quantization group (chunk) size — hard-coded for models quantized with
    # group_size=128; consider parameterizing.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale.float(),
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    # SiLU(gate) * up, halving the feature dimension from 2N to N.
    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
    silu_and_mul(c1, intermediate)

    # Re-quantize the intermediate activations to fp8 for the second GEMM.
    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)

    # Grouped GEMM 2: down projection (same hard-coded group size of 128).
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale.float(),
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Gather rows back to original token order and apply topk_weights.
    output = torch.empty_like(a)
    post_reorder_triton_kernel[(m,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        k,
        0,
        BLOCK_SIZE=512,
    )
    return output
58 changes: 58 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):

def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

Expand All @@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
)

return reorder_topk_ids, src2dst, seg_indptr


def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
    """Sort the flattened token->expert assignments for the CUTLASS MoE path.

    Returns a tuple of:
    - reorder_topk_ids: expert ids sorted ascending (stable sort),
    - src2dst: int32 map from each original (token, slot) position to its
      row in the sorted order, filled by a Triton kernel,
    - seg_indptr: zero-initialized int64 tensor of length
      local_num_experts + 1 (not populated here).
    """
    flat_ids = local_topk_ids.view(-1)
    num_entries = flat_ids.numel()
    device = local_topk_ids.device

    reorder_topk_ids, reorder_ids = torch.sort(flat_ids, stable=True)

    seg_indptr = torch.zeros(local_num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(num_entries, device=device, dtype=torch.int32)

    BLOCK_SIZE = 512
    launch_grid = (triton.cdiv(num_entries, BLOCK_SIZE),)
    compute_src2dst_triton_kernel[launch_grid](
        reorder_ids, src2dst, num_entries, BLOCK_SIZE
    )

    return reorder_topk_ids, src2dst, seg_indptr
Comment on lines +166 to +182
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function run_cutlass_moe_ep_preproess is nearly identical to run_moe_ep_preproess. This duplication can lead to maintenance issues, where a bug fix or an enhancement in one might be missed in the other.

To improve maintainability, consider refactoring these two functions into a single, more generic function. The core logic is the same, and the different parameter names (local_topk_ids vs topk_ids, local_num_experts vs num_experts) can be handled by a single implementation.



@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    num_experts,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Quantize one token's hidden vector and scatter it to the row(s) chosen
    # by src2dst, one program per source token. The output dtype comes from
    # the destination buffer (fp8 in the cutlass_w4a8_moe caller).
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    # Advance per-token pointers: each token owns `topk` consecutive entries.
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        # `num_experts` acts as the sentinel id for "no expert assigned";
        # such slots are skipped and their destination rows left untouched.
        if expert_id != num_experts:
            # Multiply by 1/scale to quantize (a1_scales holds the dequant
            # scale). A null scales pointer means "no scaling" — note that
            # `is not None` here specializes the kernel at compile time.
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            # Copy the hidden vector in BLOCK_SIZE chunks, scaling in fp32
            # before casting down to the output dtype.
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)


@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
Expand Down
Loading
Loading