 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_permute)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+    compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm, round_up
+from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
 logger = init_logger(__name__)
@@ -93,18 +93,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
-            self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
-            topk: int, global_num_experts: int, local_num_experts: int
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert self.block_shape is not None
-        # We use global_num_experts due to how moe_align_block_size handles
-        # expert_maps.
-        num_experts = global_num_experts
         block_m = self.block_shape[0]
-        M_sum = (M * topk) + num_experts * (block_m - 1)
-        M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N // 2, K))
-        workspace2 = (M_sum, max(N, K))
+        M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
+                                  expert_tokens_meta)
+        assert M_sum % block_m == 0
+
+        workspace1 = (M_sum, max(N, K))
+        workspace2 = (M_sum, max(N // 2, K))
         output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
 
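For context on the new helper: the inline arithmetic removed above is the worst-case bound that compute_aligned_M still has to cover when no per-expert token counts are available. Below is a minimal sketch of such a calculation, assuming (an assumption, not the actual deep_gemm_utils implementation) that known per-expert counts are each padded up to the block size and summed:

    # Illustrative sketch only -- not the vLLM compute_aligned_M implementation.
    from typing import Optional

    import torch

    def round_up(x: int, m: int) -> int:
        return ((x + m - 1) // m) * m

    def aligned_m_sketch(M: int, num_topk: int, local_num_experts: int,
                         alignment: int,
                         expert_num_tokens: Optional[torch.Tensor]) -> int:
        if expert_num_tokens is not None:
            # Exact routing known: pad each expert's row count to the block size.
            return sum(round_up(int(n), alignment)
                       for n in expert_num_tokens.tolist())
        # Worst case: every local expert may need up to (alignment - 1) pad rows.
        return round_up(M * num_topk + local_num_experts * (alignment - 1),
                        alignment)

Either way the result is a multiple of the block size, which is what the `assert M_sum % block_m == 0` above checks.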
@@ -131,43 +138,40 @@ def apply(
         apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
+        assert a1q_scale is not None
 
         a1q = hidden_states
         _, N, K = w1.size()
-        M, _ = output.size()
-        num_topk = topk_ids.size(1)
 
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         assert w2.size(1) == K
 
-        a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
-            a1q,
-            a1q_scale,
-            topk_ids,
-            global_num_experts,
-            expert_map,
-            self.block_shape[0],
-        )
-
-        if expert_map is not None:
-            # DeepGemm (Grouped Contiguous) kernel needs a valid B index
-            # for all rows of A. To that effect, simply compute with
-            # the 0th weight matrix.
-            # Note that this relies on the fact that corresponding topk
-            # weights would be 0 during weight multiplication.
-            expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
-
-        # Note: M_sum is different than the pre-permuted shape of a1q.
-        M_sum = a1q.size(0)
-
-        mm1_out = _resize_cache(workspace2, (M_sum, N))
-        act_out = _resize_cache(workspace13, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+        M_sum = compute_aligned_M(M=topk_ids.size(0),
+                                  num_topk=topk_ids.size(1),
+                                  local_num_experts=local_num_experts,
+                                  alignment=deep_gemm_block_shape()[0],
+                                  expert_tokens_meta=expert_tokens_meta)
+
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M_sum, K))
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace13, (M_sum, K))
-        perm_out = _resize_cache(workspace2, (M * num_topk, K))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+
+        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
+            aq=a1q,
+            aq_scale=a1q_scale,
+            topk_ids=topk_ids,
+            local_num_experts=local_num_experts,
+            expert_map=expert_map,
+            expert_tokens_meta=expert_tokens_meta,
+            aq_out=a1q_perm)
+        assert a1q.size(0) == M_sum
 
         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
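The contiguous grouped GEMM consumes activations sorted by expert, with each expert's rows padded to a multiple of the block size and an expert index supplied for the padded rows; deepgemm_moe_permute builds that layout and inv_perm records how to undo it. A tiny illustration of the layout under those assumptions (the index-per-row convention here is assumed for illustration, not taken from the kernel's documentation):

    # Illustrative sketch of an expert-contiguous, block-aligned layout.
    import torch

    block_m = 4
    tokens_per_expert = torch.tensor([3, 6])   # rows routed to experts 0 and 1
    padded = ((tokens_per_expert + block_m - 1) // block_m) * block_m  # [4, 8]
    expert_ids = torch.repeat_interleave(torch.arange(2), padded)
    print(expert_ids.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
    # padded.sum() = 12 rows feed the grouped GEMM; the pad rows carry no token
    # and are simply not gathered back during the unpermute step.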
@@ -183,14 +187,15 @@ def apply(
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
 
-        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+        if apply_router_weight_on_input:
+            topk_weights = torch.ones_like(topk_weights)
 
-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=perm_out,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        deepgemm_unpermute_and_reduce(a=mm2_out,
+                                      topk_ids=topk_ids,
+                                      topk_weights=topk_weights,
+                                      inv_perm=inv_perm,
+                                      expert_map=expert_map,
+                                      output=output)
 
 
 def deep_gemm_moe_fp8(
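The unpermute-and-reduce call replaces the index_select + TopKWeightAndReduceContiguous pair removed above: it gathers the expert-major GEMM output back into token order and applies the top-k weighted sum. A minimal, unfused reference of that step under simplifying assumptions (no expert_map masking; the function name is hypothetical, this is not the fused vLLM helper):

    # Unfused reference for the unpermute + weighted-reduce step (illustration only).
    import torch

    def unpermute_and_reduce_ref(a: torch.Tensor, inv_perm: torch.Tensor,
                                 topk_weights: torch.Tensor,
                                 output: torch.Tensor) -> None:
        M, topk = topk_weights.shape
        K = a.size(-1)
        # Gather the (M * topk) useful rows back into token-major order.
        gathered = a.index_select(0, inv_perm).view(M, topk, K)
        # Weighted sum over each token's top-k experts.
        weighted = gathered.float() * topk_weights.view(M, topk, 1).float()
        output.copy_(weighted.sum(dim=1).to(output.dtype))

When apply_router_weight_on_input is set, the router weights were already folded into the activations, which is why the diff swaps topk_weights for ones before the reduce.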