diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index f6467b3841..890934a93c 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -161,7 +161,7 @@ def fused_moe_fake( device = topk_ids.device M, topk = topk_ids.shape dtype = hidden_states.dtype if dtype is None else dtype - E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) + model_dim = w2.shape[1] moe_buf = torch.empty((M, model_dim), dtype=dtype, device=device) return moe_buf