Merged

Changes from 28 commits (40 commits total):

9dc3db1  upd (BBuf, May 12, 2025)
8abddc1  upd (BBuf, May 12, 2025)
1422361  upd (BBuf, May 12, 2025)
23f93cb  upd (BBuf, May 12, 2025)
061f26b  upd (BBuf, May 12, 2025)
9741099  refine (BBuf, May 12, 2025)
380f326  upd (BBuf, May 12, 2025)
3dbdc35  upd (BBuf, May 12, 2025)
4300d78  rebase baseline (BBuf, May 12, 2025)
7704849  upd (BBuf, May 12, 2025)
95881e2  upd (BBuf, May 12, 2025)
ef02f57  upd (BBuf, May 12, 2025)
565e29e  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, May 12, 2025)
fcc8947  upd (BBuf, May 12, 2025)
40dc6a1  Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf, May 12, 2025)
0f94fba  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, May 13, 2025)
efb3d6f  fix ci (BBuf, May 13, 2025)
f72890c  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, May 13, 2025)
51ae90d  fix ci (BBuf, May 13, 2025)
b540a22  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, May 13, 2025)
4cf655a  Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf, May 13, 2025)
b84c12a  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 7, 2025)
e51b9b7  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (merrymercy, Jun 7, 2025)
ecc6314  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 7, 2025)
1b74751  refine (BBuf, Jun 7, 2025)
86c9cd5  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 7, 2025)
e608247  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 7, 2025)
cc7827b  refine (BBuf, Jun 8, 2025)
9b15f7e  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 8, 2025)
18fc133  fallback for hip (BBuf, Jun 8, 2025)
a10cfd7  Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf, Jun 8, 2025)
efbc0da  fallback for hip (BBuf, Jun 8, 2025)
f4f945d  upd (BBuf, Jun 8, 2025)
dcc7f78  refine (BBuf, Jun 8, 2025)
9af55b5  refine (BBuf, Jun 8, 2025)
9beaff2  refine (BBuf, Jun 8, 2025)
4ac8ee9  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 8, 2025)
b6d9743  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 8, 2025)
c1a4e00  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 8, 2025)
9e1432b  Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf, Jun 8, 2025)
199 changes: 199 additions & 0 deletions benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
@@ -0,0 +1,199 @@
import torch
import triton
import triton.language as tl
from triton.testing import do_bench


# _moe_sum_reduce_kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    # Cast strides to int64 so pointer arithmetic cannot overflow 32 bits
    # on large activation tensors.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    # For each token in this block, accumulate the top-k expert outputs in
    # fp32, then apply the routed scaling factor once (sum-then-scale).
    for token_index in range(token_start, token_end):
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )


def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    # Eager baseline: sum over the top-k dim, then scale in place
    # (sum-then-scale).
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    # Note: this variant scales before summing (scale-then-sum); see the
    # review comment below.
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out
Comment on lines +102 to +106
Contributor
critical

The compute_sum_scaled_compiled function here implements scale-then-sum (torch.sum(x * routed_scaling_factor, ...)), while the Triton kernel (_moe_sum_reduce_kernel) and the moe_sum_reduce_torch_compile function in python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (lines 1464-1467) implement sum-then-scale.

If this benchmark is intended to measure the specific sum-then-scale operation being fused into the model (as suggested by the removal of the scaling-factor application after self.experts in deepseek_v2.py), then this compiled version benchmarks a different computation. The two orderings are algebraically equivalent, but they differ numerically in bfloat16 and have different cost profiles (scale-then-sum performs one multiply per input element, i.e. topk times more multiplies than scaling the summed output), which could make the comparison misleading.

Could you clarify whether this difference is intentional for this benchmark, or whether it should be aligned with the sum-then-scale logic used elsewhere in this PR (e.g., by reusing the logic of moe_sum_reduce_torch_compile from fused_moe.py)?


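For reference, a minimal sketch of an aligned sum-then-scale compiled variant (the function name is hypothetical; the body mirrors compute_sum_scaled_baseline above):

@torch.compile
def compute_sum_scaled_compiled_sum_then_scale(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    # Sum over the top-k dimension first, then scale once per output
    # element, matching the order used by _moe_sum_reduce_kernel.
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out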

def get_benchmark():
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            else:
                moe_sum_reduce(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
            )

        # Convert ms to us to match the ylabel; return (median, min, max)
        # to match the [0.5, 0.2, 0.8] quantiles above.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark


def verify_correctness(num_tokens=1024):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    if torch.allclose(
        out_baseline, out_compiled, atol=1e-2, rtol=1e-2
    ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    benchmark = get_benchmark()
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
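To reproduce locally (assuming a CUDA GPU with a recent PyTorch and Triton), the script can be run directly with python benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py. A quick standalone check of the Triton path might look like the following sketch (shapes chosen to match the benchmark defaults):

x = torch.randn(8, 9, 4096, device="cuda", dtype=torch.bfloat16)
out = torch.empty(8, 4096, device="cuda", dtype=torch.bfloat16)
moe_sum_reduce(x, out, routed_scaling_factor=0.3)
# fp32 reference: sum over the top-k dim, then scale (sum-then-scale).
ref = (x.float().sum(dim=1) * 0.3).to(torch.bfloat16)
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)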