From 0a0a2275745779afd40fb981fd78705533fb9080 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 20:32:34 +0000 Subject: [PATCH] Re-fuse triton moe weight application Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aa0bd553fc32..0817879c4d57 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor, qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale - invoke_fused_moe_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, #True, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - block_shape=block_shape) - - if True: - intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) - intermediate_cache3.mul_( - curr_topk_weights.view(tokens_in_chunk, -1, 1)) + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx])