Skip to content

Commit dd25c4d

Browse files
Implement per-token w4afp8 CUTLASS MoE GEMM for FP8 dispatch
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
1 parent 6dfa8a4 commit dd25c4d

File tree

15 files changed

+1302
-409
lines changed

15 files changed

+1302
-409
lines changed

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,82 @@ def triton_scaled_mm(
20552055
return result.to(out_dtype)
20562056

20572057

2058+
@triton.jit
def interleave_int4xfp8_Hopper_kernel(
    ptr_4b,                  # source buffer, addressed as 16-bit words (4 packed 4-bit values per word)
    ptr_4b_interleaved,      # destination buffer, same element type/size as ptr_4b
    rows,                    # number of rows, counted in 4-bit elements
    cols,                    # number of columns, counted in 4-bit elements (so cols // 4 words per row)
    grid_0: tl.constexpr,    # launch grid size; used as the stride of the outer grid-stride loop
    BLOCK_TX: tl.constexpr,  # lanes per partition along the column axis (16 words = 64 int4s)
    BLOCK_TY: tl.constexpr,  # partitions processed per inner-loop iteration
):
    """Shuffle packed 4-bit weight data into the layout expected by the
    w4a8 (int4 x fp8) CUTLASS MoE GEMM on Hopper.

    Each program walks row-pairs (block_id) in a grid-stride loop; for each
    row-pair it walks column partitions of 64 int4 elements (16 words) in
    steps of BLOCK_TY.  Within a partition, word `lane_id` of rows `row_id`
    and `row_id + 8` is scattered to an interleaved (dst_row_id, dst_col_id)
    position.  NOTE(review): the exact dst pattern presumably matches the
    Hopper MMA operand swizzle expected by the CUTLASS kernel — confirm
    against the consuming GEMM; it is not derivable from this file alone.
    """
    blockIdx_x = tl.program_id(0)
    lane_x = tl.arange(0, BLOCK_TX)  # word index within a 16-word partition
    lane_y = tl.arange(0, BLOCK_TY)  # partition indices handled concurrently

    rows_half = rows // 2            # rows are consumed two at a time (row_id and row_id + 8)
    cols_div4 = cols // 4            # row pitch in 16-bit words

    block_id = blockIdx_x
    # Outer grid-stride loop over row-pairs.
    while block_id < rows_half:
        partition_id = lane_y
        mask_partition = partition_id < (cols // 64)  # cols // 64 partitions of 16 words each
        # Advance BLOCK_TY partitions per iteration until every partition of
        # this row-pair has been processed (mask all-false terminates).
        while tl.max(mask_partition.to(tl.int32)) > 0:
            lane_id = lane_x[None, :]          # shape (1, BLOCK_TX)
            partition = partition_id[:, None]  # shape (BLOCK_TY, 1)
            # Map block_id -> source row: 8 consecutive block_ids cover rows
            # r .. r+7 of a 16-row group; the paired row is row_id + 8.
            row_id = (block_id // 8) * 16 + (block_id % 8)
            dst_row_id = row_id + ((lane_id % 8) // 4) * 8
            mma_id = lane_id // 8
            # Even destination word slots; the paired row's word lands at +1.
            interleaved_lane_id = mma_id * 8 + (lane_id % 4) * 2
            col_id = partition * 16 + lane_id
            dst_col_id = partition * 16 + interleaved_lane_id

            # Flat word offsets for the two source rows of the pair.
            src_id_a = row_id * cols_div4 + col_id
            src_id_b = (row_id + 8) * cols_div4 + col_id

            dst_id = dst_row_id * cols_div4 + dst_col_id

            # Guard ragged edges: stay inside the row pitch, inside the row
            # count (both rows of the pair), and inside the live partitions.
            valid = (
                (col_id < cols_div4)
                & (dst_row_id < rows)
                & (row_id + 8 < rows)
                & mask_partition[:, None]
            )

            # Load one word from each row of the pair; uint16 keeps the raw
            # bit pattern regardless of the int16 view created by the caller.
            fp4x2_a = tl.load(ptr_4b + src_id_a, mask=valid, other=0).to(tl.uint16)
            fp4x2_b = tl.load(ptr_4b + src_id_b, mask=valid, other=0).to(tl.uint16)

            # Scatter as an adjacent (even, odd) word pair at the destination.
            tl.store(ptr_4b_interleaved + dst_id, fp4x2_a, mask=valid)
            tl.store(ptr_4b_interleaved + (dst_id + 1), fp4x2_b, mask=valid)

            partition_id = partition_id + BLOCK_TY
            mask_partition = partition_id < (cols // 64)

        block_id += grid_0
2111+
2112+
2113+
def interleave_int4(int4_ptr, int4_interleaved_ptr, rows, cols):
    """Launch the Hopper int4 interleaving kernel.

    Reinterprets both packed-int4 tensors as 16-bit words (the C++ analogue
    of ``reinterpret_cast<uint16_t*>``) and dispatches
    ``interleave_int4xfp8_Hopper_kernel`` over a fixed 1-D grid.  ``rows``
    and ``cols`` are counted in 4-bit elements.
    """
    num_programs = 1024  # fixed 1-D launch grid; kernel grid-strides past it

    # View the packed nibbles as 16-bit words; no data is copied.
    src_words = int4_ptr.view(torch.int16)
    dst_words = int4_interleaved_ptr.view(torch.int16)

    interleave_int4xfp8_Hopper_kernel[(num_programs,)](
        src_words,
        dst_words,
        rows,
        cols,
        grid_0=num_programs,
        BLOCK_TX=16,  # 16 words = one 64-int4 partition per lane row
        BLOCK_TY=32,  # partitions handled per inner iteration
    )
2132+
2133+
20582134
if _is_cuda:
20592135
if enable_sgl_per_token_group_quant_8bit:
20602136

sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
244244
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
245245
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
246246
" Tensor problem_sizes, Tensor a_strides, "
247-
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
247+
" Tensor b_strides, Tensor d_strides, Tensor sa_strides, Tensor sb_strides, "
248248
" int chunk_size, int topk) -> ()");
249249
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
250250

0 commit comments

Comments (0)