@@ -350,19 +350,31 @@ def select_experts(
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
-        topk_weights, topk_ids = biased_grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            correction_bias=correction_bias,
-            topk=top_k,
-            renormalize=renormalize,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            routed_scaling_factor=routed_scaling_factor,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        device = hidden_states.device
+        if device == torch.device("cpu") and _is_cpu_amx:
+            topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+                hidden_states,
+                router_logits,
+                correction_bias,
+                top_k,
+                renormalize,
+                num_expert_group,
+                topk_group,
+            )
+        else:
+            topk_weights, topk_ids = biased_grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                routed_scaling_factor=routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 elif torch_native and custom_routing_function is None:
     assert (
         num_token_non_padded is None
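
For context, a minimal, hypothetical sketch of the dispatch pattern this hunk introduces: prefer the fused sgl_kernel AMX top-k custom op when the activations live on CPU and AMX support was detected, otherwise fall back to the generic biased_grouped_topk path. The wrapper name and the has_cpu_amx/fallback parameters below are illustrative only; the torch.ops call mirrors the arguments shown in the diff.

    # Hypothetical helper; only the torch.ops.sgl_kernel call mirrors the diff above.
    import torch

    def _dispatch_biased_grouped_topk(
        hidden_states, router_logits, correction_bias,
        top_k, renormalize, num_expert_group, topk_group,
        has_cpu_amx, fallback,
    ):
        # Fast path: fused CPU kernel registered by sgl_kernel as a torch custom op
        # (available only when sgl_kernel is built with AMX support).
        if hidden_states.device == torch.device("cpu") and has_cpu_amx:
            return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
                hidden_states, router_logits, correction_bias,
                top_k, renormalize, num_expert_group, topk_group,
            )
        # Generic path, which also handles the extra arguments the CPU op does not take.
        return fallback()

Note that in this hunk the CPU op receives only the seven positional arguments shown; num_fused_shared_experts, routed_scaling_factor, num_token_non_padded, and expert_location_dispatch_info are passed only on the fallback path.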