Fuse routed scaling factor in deepseek #6970
Merged
Commits (40; this view shows changes from 28 of them):

- 9dc3db1 upd (BBuf)
- 8abddc1 upd (BBuf)
- 1422361 upd (BBuf)
- 23f93cb upd (BBuf)
- 061f26b upd (BBuf)
- 9741099 refine (BBuf)
- 380f326 upd (BBuf)
- 3dbdc35 upd (BBuf)
- 4300d78 rebase baseline (BBuf)
- 7704849 upd (BBuf)
- 95881e2 upd (BBuf)
- ef02f57 upd (BBuf)
- 565e29e Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- fcc8947 upd (BBuf)
- 40dc6a1 Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf)
- 0f94fba Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- efb3d6f fix ci (BBuf)
- f72890c Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- 51ae90d fix ci (BBuf)
- b540a22 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- 4cf655a Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf)
- b84c12a Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- e51b9b7 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (merrymercy)
- ecc6314 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- 1b74751 refine (BBuf)
- 86c9cd5 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- e608247 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- cc7827b refine (BBuf)
- 9b15f7e Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- 18fc133 fallback for hip (BBuf)
- a10cfd7 Merge branch 'fuse_routed_scaling_factor_in_deepseek' of github.com:s… (BBuf)
- efbc0da fallback for hip (BBuf)
- f4f945d upd (BBuf)
- dcc7f78 refine (BBuf)
- 9af55b5 refine (BBuf)
- 9beaff2 refine (BBuf)
- 4ac8ee9 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- b6d9743 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- c1a4e00 Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
- 9e1432b Merge branch 'main' into fuse_routed_scaling_factor_in_deepseek (BBuf)
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py (new file: 199 additions, 0 deletions)
```python
import torch
import triton
import triton.language as tl
from triton.testing import do_bench


# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )


def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out


def get_benchmark():
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            else:
                moe_sum_reduce(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def verify_correctness(num_tokens=1024):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    if torch.allclose(
        out_baseline, out_compiled, atol=1e-2, rtol=1e-2
    ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    benchmark = get_benchmark()
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
```
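As a quick sanity check on the launch configuration in `moe_sum_reduce`: with the defaults `BLOCK_M = 1` and `BLOCK_DIM = 2048`, the benchmark's `(1024, 9, 4096)` input launches one program per token along axis 0 and two hidden-dimension tiles along axis 1. A minimal pure-Python sketch (no GPU required; `cdiv` mirrors `triton.cdiv`):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as triton.cdiv computes it.
    return -(-a // b)

# Shapes taken from the benchmark defaults above.
token_num, topk_num, hidden_dim = 1024, 9, 4096
BLOCK_M, BLOCK_DIM = 1, 2048

grid = (cdiv(token_num, BLOCK_M), cdiv(hidden_dim, BLOCK_DIM))
print(grid)  # → (1024, 2)
```

Each of the 1024 × 2 programs accumulates `topk_num` partial loads in float32 before the single scale and store, which is why the kernel needs only one pass over the input.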
The `compute_sum_scaled_compiled` function here implements scale-then-sum (`torch.sum(x * routed_scaling_factor, ...)`), while the Triton kernel (`_moe_sum_reduce_kernel`) and the `moe_sum_reduce_torch_compile` function in `python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py` (lines 1464-1467) implement sum-then-scale.

If this benchmark is intended to compare the performance of the specific sum-then-scale operation being fused into the model (as suggested by the removal of the scaling-factor application after `self.experts` in `deepseek_v2.py`), then this compiled version tests a mathematically different operation. This could lead to misleading benchmark results and correctness issues in the comparison.

Could you clarify whether this difference is intentional for this specific benchmark, or should it be aligned with the sum-then-scale logic used elsewhere in this PR (e.g., by using the same logic as `moe_sum_reduce_torch_compile` from `fused_moe.py`)?
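For intuition on the distinction raised here: the two orderings are algebraically identical (multiplication distributes over the sum), so any divergence comes from the order in which floating-point rounding is applied, which matters more in low precision such as the bfloat16 used by the benchmark. A minimal pure-Python sketch with toy values (standing in for the per-expert tensors in the benchmark; the specific numbers are illustrative only):

```python
import math

routed_scaling_factor = 0.3
expert_outputs = [0.125, -2.5, 3.75, 0.0625]  # toy per-expert outputs for one element

# sum-then-scale: accumulate first, apply the scaling factor once at the end
# (the ordering used by _moe_sum_reduce_kernel).
sum_then_scale = sum(expert_outputs) * routed_scaling_factor

# scale-then-sum: scale every term before accumulating
# (the ordering used by compute_sum_scaled_compiled).
scale_then_sum = sum(x * routed_scaling_factor for x in expert_outputs)

# In float64 the two agree to within rounding error; in bfloat16 every
# intermediate is rounded to ~8 bits of mantissa, so the gap can be larger,
# and scale-then-sum also performs topk extra multiplies per output element.
assert math.isclose(sum_then_scale, scale_then_sum, rel_tol=1e-9)
```

This is why benchmarking scale-then-sum as a stand-in for a sum-then-scale fusion can be misleading: the results match numerically only up to rounding, and the work per element differs.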