Merged
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/topk.py
@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk(
-        gating_output,
+        gating_output.to(dtype=torch.float32),
Contributor

Severity: medium

This explicit type cast is a good fix for the dtype mismatch error in the CI.

As you mentioned in the PR description, this is a temporary workaround. To improve long-term maintainability and ensure it is addressed later, I suggest adding a TODO comment here. That will make the reason for the cast clear to anyone reading the code in the future.

Suggested change
gating_output.to(dtype=torch.float32),
# TODO: Cast to float32 to match correction_bias.dtype. This is a temporary
# workaround for a limitation in the aiter kernel. Remove this cast once
# bf16/bf16 mixed-precision GEMM is supported.
# Ref: https://github.com/sgl-project/sglang/pull/7825
gating_output.to(dtype=torch.float32),

correction_bias,
topk_weights,
topk_ids,
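
To illustrate the kind of dtype mismatch this cast avoids, here is a minimal, self-contained sketch. The function below is a hypothetical simplified stand-in for `aiter_biased_grouped_topk` (it is not the real kernel): it adds a float32 `correction_bias` to the gating logits and takes a top-k, which fails in fused kernels when the two operands have different dtypes, so the bf16 gating output is cast up to float32 first, mirroring the fix in this diff.

```python
import torch


def biased_grouped_topk_sketch(
    gating_output: torch.Tensor,   # e.g. bf16 logits of shape (tokens, experts)
    correction_bias: torch.Tensor, # float32 bias of shape (experts,)
    topk: int = 2,
):
    # Cast the logits to float32 to match correction_bias.dtype.
    # This mirrors the temporary workaround in the diff above: the fused
    # kernel assumes both inputs share a dtype, so bf16 logits must be
    # promoted before the biased selection.
    scores = gating_output.to(dtype=torch.float32) + correction_bias

    # Simplified stand-in for the fused biased grouped top-k selection.
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)

    # Match the output buffer dtypes used in the caller (float32 / int32).
    return topk_weights, topk_ids.to(torch.int32)


if __name__ == "__main__":
    gating = torch.randn(4, 8, dtype=torch.bfloat16)
    bias = torch.zeros(8, dtype=torch.float32)
    weights, ids = biased_grouped_topk_sketch(gating, bias)
    print(weights.shape, weights.dtype, ids.dtype)
```

Without the `.to(dtype=torch.float32)` cast, the bf16 + float32 addition still works in eager PyTorch (type promotion handles it), but a fused kernel with strict dtype checks would reject the mixed inputs, which is exactly the CI failure this change works around.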