 import functools
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import triton
@@ -137,7 +137,7 @@ def fused_moe_kernel(
 
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
 
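The annotation removed here, `-> (torch.Tensor, torch.Tensor, torch.Tensor)`, is not actually a type: a parenthesized list of classes just evaluates to a tuple object when the `def` executes, and static checkers such as mypy reject it. `Tuple[...]` from `typing`, added to the imports in the first hunk, is the correct spelling on the Python versions that still need it. A minimal sketch of the corrected pattern; `row_max` is a hypothetical stand-in, not a function from this diff:

```python
from typing import Tuple

import torch


def row_max(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return per-row max values and their indices as a typed pair."""
    values, indices = torch.max(x, dim=-1)
    return values, indices


# The old spelling `-> (torch.Tensor, torch.Tensor)` does not fail at
# runtime (the tuple is merely stored in __annotations__), but it is not
# a valid type expression, so mypy and similar tools flag it.
```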
@@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
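Typing the kernel launch configuration as `Dict[str, Any]` instead of bare `dict` records that it is a string-keyed mapping of meta-parameters, and the explicit `-> None` marks the function as a launcher with no return value. A hedged sketch of how such a config is commonly splatted into a Triton-style launch; the key names below are illustrative assumptions, not taken from this diff:

```python
from typing import Any, Dict

# Illustrative meta-parameters; the real config may use different keys.
config: Dict[str, Any] = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 8,
}


def launch(config: Dict[str, Any]) -> None:
    """Launchers return nothing, hence the explicit `-> None`."""
    # With Triton, entries like these are usually forwarded as
    # compile-time constants: kernel[grid](..., **config).
    for name, value in config.items():
        print(f"{name} = {value}")
```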