Skip to content

Commit e984d50

Browse files
authored
enable aiter_biased_grouped_topk kernel (#7423)
1 parent 755f314 commit e984d50

3 files changed

Lines changed: 29 additions & 2 deletions

File tree

python/sglang/srt/layers/moe/topk.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from sglang.srt.managers.schedule_batch import global_server_args_dict
3131
from sglang.srt.utils import (
3232
cpu_has_amx_support,
33+
get_bool_env_var,
3334
get_compiler_backend,
3435
is_cpu,
3536
is_cuda,
@@ -38,6 +39,7 @@
3839

3940
_is_cuda = is_cuda()
4041
_is_hip = is_hip()
42+
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
4143
_is_cpu_amx_available = cpu_has_amx_support()
4244
_is_cpu = is_cpu()
4345

@@ -46,6 +48,11 @@
4648

4749
if _is_cuda or _is_hip:
4850
from sgl_kernel import topk_softmax
51+
if _use_aiter:
52+
try:
53+
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
54+
except ImportError:
55+
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
4956

5057

5158
def fused_topk_torch_native(
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
347354
topk_ids, expert_location_dispatch_info, num_token_non_padded
348355
)
349356
return topk_weights, topk_ids
357+
elif _use_aiter:
358+
token = gating_output.shape[0]
359+
device = gating_output.device
360+
assert (
361+
hidden_states.shape[0] == gating_output.shape[0]
362+
), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
363+
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
364+
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
365+
aiter_biased_grouped_topk(
366+
gating_output,
367+
correction_bias,
368+
topk_weights,
369+
topk_ids,
370+
num_expert_group,
371+
topk_group,
372+
renormalize,
373+
routed_scaling_factor,
374+
)
375+
return topk_weights, topk_ids
350376
else:
351377
biased_grouped_topk_fn = (
352378
torch.compile(

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def capture(self) -> None:
421421
empty_cache=False,
422422
)
423423
capture_range.set_description(
424-
f"Capturing batches ({avail_mem=:.2f} GB)"
424+
f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
425425
)
426426

427427
with patch_model(

python/sglang/srt/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,8 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
388388
final_hidden_states = self.experts(
389389
hidden_states=hidden_states, router_logits=router_logits
390390
)
391-
if not _is_cuda:
391+
if not _is_cuda and not _use_aiter:
392+
# routed_scaling_factor is fused into biased_grouped_topk on the CUDA and aiter paths, so only the remaining backends apply it here
392393
final_hidden_states *= self.routed_scaling_factor
393394
if shared_output is not None:
394395
final_hidden_states = final_hidden_states + shared_output

0 commit comments

Comments
 (0)