Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
47a5c62
refactor apply_w8a8_block_fp8_linear
ChangyiYang May 23, 2025
87a5629
refactoring the dispatching logic to avoid filtering overhead
ChangyiYang May 25, 2025
cb4909c
Modify comments in fp8_kernel.py
ChangyiYang May 26, 2025
fecf409
create w8a8_block_fp8_matmul_triton, leave w8a8_block_fp8_matmul as a…
ChangyiYang May 26, 2025
cf851a1
fix typo
ChangyiYang May 26, 2025
48e3781
Update kernel function ref in bench_fp8_blockwise_gemm.py
ChangyiYang May 27, 2025
e8db7ce
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
Alcanderian May 27, 2025
bea308a
fix referenced before assignment error
ChangyiYang May 28, 2025
2c78443
fix bug that output_dtype is not passed correctly
ChangyiYang May 28, 2025
807fc4e
fix bug that output_dtype is not passed correctly
ChangyiYang May 28, 2025
b0e2c42
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
zhyncs May 28, 2025
bdab9a1
[PD Perf] replace Queue with FastQueue (#6649)
whybeyoung May 28, 2025
279885a
[Bugfix] Fix slice operation when chunk size mismatch (#6697)
ShangmingCai May 28, 2025
8f84f9c
[Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (#…
ShangmingCai May 28, 2025
0de019b
[CI] Fix setup of disaggregation with different tp (#6706)
ShangmingCai May 28, 2025
cbf1e96
[PD] Remove Unnecessary Exception Handling for FastQueue.get() (#6712)
Hongbosherlock May 28, 2025
1a55a95
Fuse routed_scaling_factor in DeepSeek (#6710)
fzyzcjy May 28, 2025
6374a27
Overlap two kernels in DeepSeek with communication (#6711)
fzyzcjy May 28, 2025
cdedbb3
Minor refactor two-batch overlap (#6682)
fzyzcjy May 28, 2025
93ef744
Speed up when having padding tokens two-batch overlap (#6668)
fzyzcjy May 28, 2025
7df4699
[Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (…
Fridge003 May 28, 2025
8368a5d
create new dispatching function flashinfer_gemm_w8a8_block_fp8_linear
ChangyiYang May 29, 2025
2664692
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
ChangyiYang May 29, 2025
4ef9146
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
ChangyiYang May 29, 2025
37bf9c9
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
ChangyiYang May 29, 2025
d4e3cf6
Merge branch 'main' into refactor_apply_w8a8_block_fp8_linear
zhyncs May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)


# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
Expand Down
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def dummy_func(*args, **kwargs):
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
input_to_float8,
is_sm100_supported,
normalize_e4m3fn_to_e4m3fnuz,
Expand Down Expand Up @@ -209,6 +209,8 @@ def __init__(self, quant_config: Fp8Config):
# Marlin doesn't support block-wise fp8
self.use_marlin = False

self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -417,7 +419,7 @@ def apply(
)

if self.block_quant:
return apply_w8a8_block_fp8_linear(
return self.w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
Expand Down
184 changes: 118 additions & 66 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,59 @@ def select_w8a8_block_fp8_matmul_kernel(M, N, META):
return _w8a8_block_fp8_matmul


def w8a8_block_fp8_matmul(
def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> Tuple[int, int, int, torch.Tensor]:
    """Validate the operands of a block-wise fp8 matmul and allocate its output.

    Args:
        A: Left operand of shape [..., K]; must be contiguous.
        B: Right operand of shape [N, K]; must be contiguous.
        As: Per-token-group scales for `A`; last dim must equal ceil(K / block_k).
        Bs: Per-block scales for `B` of shape [ceil(N / block_n), ceil(K / block_k)].
        block_size: Two-element list [block_n, block_k].
        output_dtype: dtype of the allocated output tensor.

    Returns:
        Tuple (M, N, K, C): M is A's flattened batch size, (N, K) are B's
        dimensions, and C is an uninitialized output tensor of shape
        A.shape[:-1] + (N,) with dtype `output_dtype`.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]

    # Flatten all leading (batch) dimensions of A into a single M dimension.
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Run a block-wise fp8 matmul through the DeepGEMM backend.

    DeepGEMM only produces bfloat16 output, so `output_dtype` must be
    torch.bfloat16 and JIT DeepGEMM must be enabled.
    """
    M, N, K, out = prepare_block_fp8_matmul_inputs(
        A, B, As, Bs, block_size, output_dtype
    )

    # Deepgemm only supports output tensor type as bfloat16
    assert out.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM

    # Prefer the registered custom op when available so torch.compile /
    # tracing can see the call; otherwise invoke the kernel directly.
    if not supports_custom_op():
        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), out)
    else:
        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, out)

    return out


def w8a8_block_fp8_matmul_triton(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise
    quantization using the Triton kernel.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)

    block_n, block_k = block_size

    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
    if configs:
        # If an optimal configuration map has been found, look up the
        # optimal config
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        # Default config
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

    def grid(META):
        # 1-D launch grid: one program per (M-tile, N-tile) pair.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)

    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C


# Universal entry point, for testing purposes.
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Dispatch a block-wise fp8 matmul to DeepGEMM or Triton.

    DeepGEMM is selected only when JIT DeepGEMM is enabled and the
    requested output dtype is bfloat16 (the only dtype it supports);
    every other case falls through to the Triton kernel.
    """
    use_deepgemm = _ENABLE_JIT_DEEPGEMM and output_dtype == torch.bfloat16
    backend = (
        w8a8_block_fp8_matmul_deepgemm if use_deepgemm else w8a8_block_fp8_matmul_triton
    )
    return backend(A, B, As, Bs, block_size, output_dtype=output_dtype)


@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
x_ptr,
Expand Down
Loading
Loading