Skip to content

Commit 7c4b6f9

Browse files
committed
Add fused biased_grouped_topk (sgl-project#20)
* Add fused biased_grouped_topk
* Add record function
1 parent 269594f commit 7c4b6f9

File tree

1 file changed

+25
-13
lines changed
  • python/sglang/srt/layers/moe

1 file changed

+25
-13
lines changed

python/sglang/srt/layers/moe/topk.py

Lines changed: 25 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -350,19 +350,31 @@ def select_experts(
350350
expert_location_dispatch_info=expert_location_dispatch_info,
351351
)
352352
else:
353-
topk_weights, topk_ids = biased_grouped_topk(
354-
hidden_states=hidden_states,
355-
gating_output=router_logits,
356-
correction_bias=correction_bias,
357-
topk=top_k,
358-
renormalize=renormalize,
359-
num_expert_group=num_expert_group,
360-
topk_group=topk_group,
361-
num_fused_shared_experts=num_fused_shared_experts,
362-
routed_scaling_factor=routed_scaling_factor,
363-
num_token_non_padded=num_token_non_padded,
364-
expert_location_dispatch_info=expert_location_dispatch_info,
365-
)
353+
device = hidden_states.device
354+
if device == torch.device("cpu") and _is_cpu_amx:
355+
topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
356+
hidden_states,
357+
router_logits,
358+
correction_bias,
359+
top_k,
360+
renormalize,
361+
num_expert_group,
362+
topk_group,
363+
)
364+
else:
365+
topk_weights, topk_ids = biased_grouped_topk(
366+
hidden_states=hidden_states,
367+
gating_output=router_logits,
368+
correction_bias=correction_bias,
369+
topk=top_k,
370+
renormalize=renormalize,
371+
num_expert_group=num_expert_group,
372+
topk_group=topk_group,
373+
num_fused_shared_experts=num_fused_shared_experts,
374+
routed_scaling_factor=routed_scaling_factor,
375+
num_token_non_padded=num_token_non_padded,
376+
expert_location_dispatch_info=expert_location_dispatch_info,
377+
)
366378
elif torch_native and custom_routing_function is None:
367379
assert (
368380
num_token_non_padded is None

0 commit comments

Comments (0)