|
@@ -19,7 +19,6 @@
     maybe_warn_marlin_atomic_add,
 )
 from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op


 def fused_marlin_moe(
@@ -241,44 +240,6 @@ def fused_marlin_moe(
     return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)


-def fused_marlin_moe_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    gating_output: torch.Tensor | None,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    quant_type_id: int,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    global_scale1: torch.Tensor | None = None,
-    global_scale2: torch.Tensor | None = None,
-    expert_map: torch.Tensor | None = None,
-    g_idx1: torch.Tensor | None = None,
-    g_idx2: torch.Tensor | None = None,
-    sort_indices1: torch.Tensor | None = None,
-    sort_indices2: torch.Tensor | None = None,
-    w1_zeros: torch.Tensor | None = None,
-    w2_zeros: torch.Tensor | None = None,
-    workspace: torch.Tensor | None = None,
-    intermediate_cache13: torch.Tensor | None = None,
-    intermediate_cache2: torch.Tensor | None = None,
-    is_k_full: bool = True,
-    output: torch.Tensor | None = None,
-    inplace: bool = False,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="fused_marlin_moe",
-    op_func=fused_marlin_moe,
-    fake_impl=fused_marlin_moe_fake,
-)
-
-
 class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
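For context on the mechanism being removed: direct_register_custom_op registered fused_marlin_moe as a torch custom op and attached fused_marlin_moe_fake as its fake (meta) implementation, which returns an empty tensor of the matching shape so torch.compile can trace the op without executing the Marlin kernels. Below is a minimal standalone sketch of that same pattern using the public torch.library API rather than vLLM's wrapper; the demo::scaled_copy op name and its body are hypothetical and exist only for illustration:

import torch

# Register a custom op so torch.compile treats it as an opaque call
# instead of tracing into its implementation.
@torch.library.custom_op("demo::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs eagerly when the op is executed.
    return x * scale

@scaled_copy.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: propagates shape/dtype only, mirroring
    # fused_marlin_moe_fake's `return torch.empty_like(hidden_states)`.
    return torch.empty_like(x)

if __name__ == "__main__":
    # The op is callable through the dispatcher once registered.
    y = torch.ops.demo.scaled_copy(torch.ones(4), 2.0)

With the registration deleted in this diff, fused_marlin_moe is presumably invoked as a plain Python function rather than through the custom-op dispatcher.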
|