|
30 | 30 | from sglang.srt.managers.schedule_batch import global_server_args_dict |
31 | 31 | from sglang.srt.utils import ( |
32 | 32 | cpu_has_amx_support, |
| 33 | + get_bool_env_var, |
33 | 34 | get_compiler_backend, |
34 | 35 | is_cpu, |
35 | 36 | is_cuda, |
|
38 | 39 |
|
# Backend capability flags, each evaluated once when this module is imported.
39 | 40 | _is_cuda = is_cuda() |
40 | 41 | _is_hip = is_hip() |
# The aiter path is opt-in: it is enabled only when SGLANG_USE_AITER evaluates
# truthy AND is_hip() returned true.
# NOTE(review): get_bool_env_var is assumed to parse the environment variable
# as a boolean -- confirm against sglang.srt.utils.
| 42 | +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip |
41 | 43 | _is_cpu_amx_available = cpu_has_amx_support() |
42 | 44 | _is_cpu = is_cpu() |
43 | 45 |
|
|
46 | 48 |
|
47 | 49 | if _is_cuda or _is_hip: |
48 | 50 | from sgl_kernel import topk_softmax |
# Import the aiter-backed grouped top-k kernel only when the aiter path has
# been enabled; otherwise the name is never bound and must not be referenced.
if _use_aiter:
    try:
        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
    except ImportError as err:
        # Chain the underlying failure explicitly so the traceback reports it
        # as the direct cause (missing package vs. broken install) instead of
        # as an unrelated error raised "during handling" of another one.
        raise ImportError(
            "aiter is required when SGLANG_USE_AITER is set to True"
        ) from err
49 | 56 |
|
50 | 57 |
|
51 | 58 | def fused_topk_torch_native( |
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu( |
347 | 354 | topk_ids, expert_location_dispatch_info, num_token_non_padded |
348 | 355 | ) |
349 | 356 | return topk_weights, topk_ids |
| 357 | + elif _use_aiter: |
| 358 | + token = gating_output.shape[0] |
| 359 | + device = gating_output.device |
| 360 | + assert ( |
| 361 | + hidden_states.shape[0] == gating_output.shape[0] |
| 362 | + ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}" |
| 363 | + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) |
| 364 | + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) |
| 365 | + aiter_biased_grouped_topk( |
| 366 | + gating_output, |
| 367 | + correction_bias, |
| 368 | + topk_weights, |
| 369 | + topk_ids, |
| 370 | + num_expert_group, |
| 371 | + topk_group, |
| 372 | + renormalize, |
| 373 | + routed_scaling_factor, |
| 374 | + ) |
| 375 | + return topk_weights, topk_ids |
350 | 376 | else: |
351 | 377 | biased_grouped_topk_fn = ( |
352 | 378 | torch.compile( |
|
0 commit comments