@@ -119,7 +119,6 @@ def __init__(self, moe: FusedMoEConfig):
         super().__init__()
         self.moe = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
-        self.topk_indices_dtype = None
 
     @abstractmethod
     def create_weights(
@@ -244,7 +243,7 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
         else:
             return None
 
-    def init_prepare_finalize(
+    def maybe_init_modular_kernel(
         self, layer: torch.nn.Module
     ) -> FusedMoEModularKernel | None:
         assert self.moe is not None
@@ -260,8 +259,6 @@ def init_prepare_finalize(
         logger.debug(
             "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
         )
-        assert self.topk_indices_dtype is None
-        self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
         experts = self.select_gemm_impl(prepare_finalize, layer)
         return FusedMoEModularKernel(
             prepare_finalize,
@@ -289,6 +286,10 @@ def get_fused_moe_quant_config(
     ) -> FusedMoEQuantConfig | None:
         raise NotImplementedError
 
+    @property
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return None
+
     @property
     def supports_eplb(self) -> bool:
         return False
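
For reference, a minimal standalone sketch of the pattern this hunk introduces (the class names below are illustrative, not vLLM's): topk_indices_dtype becomes a read-only property with a None default that concrete methods override, instead of an attribute mutated during initialization.

import torch


class BaseMethod:
    # Default: the method imposes no constraint on the top-k index dtype.
    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return None


class Int32IndicesMethod(BaseMethod):
    # A concrete method simply overrides the property.
    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        return torch.int32


assert BaseMethod().topk_indices_dtype is None
assert Int32IndicesMethod().topk_indices_dtype is torch.int32
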
@@ -328,31 +329,33 @@ def apply(
 class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def __init__(
         self,
-        old_moe_method: FusedMoEMethodBase,
+        old_quant_method: FusedMoEMethodBase,
         fused_experts: FusedMoEModularKernel,
     ):
-        super().__init__(old_moe_method.moe)
-        # Find better way to copy attributes?
-        # self.__dict__.update(old_moe_method.__dict__)
-
-        self.moe_quant_config = old_moe_method.moe_quant_config
+        super().__init__(old_quant_method.moe)
+        # Find better way to copy attributes? Should we even copy attributes?
+        # self.__dict__.update(old_quant_method.__dict__)
+        self.moe_quant_config = old_quant_method.moe_quant_config
         self.fused_experts = fused_experts
-        self.topk_indices_dtype = old_moe_method.topk_indices_dtype
-        self.disable_expert_map = not fused_experts.supports_expert_map()
-        self.old_method_name = old_moe_method.__class__.__name__
-        self._supports_eplb = old_moe_method.supports_eplb
-        self._allow_inplace = old_moe_method.allow_inplace
-        if isinstance(old_moe_method, torch.nn.Module):
-            self.load_state_dict(old_moe_method.state_dict())
-        logger.debug("Swapping out %s", self.old_method_name)
+        self.disable_expert_map = getattr(
+            old_quant_method,
+            "disable_expert_map",
+            not fused_experts.supports_expert_map(),
+        )
+        self.old_quant_method = old_quant_method
+        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
+
+    @property
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return self.fused_experts.prepare_finalize.topk_indices_dtype()
 
     @property
     def supports_eplb(self) -> bool:
-        return self._supports_eplb
+        return self.old_quant_method.supports_eplb
 
     @property
     def allow_inplace(self) -> bool:
-        return self._allow_inplace
+        return self.old_quant_method.allow_inplace
 
     def create_weights(
         self,
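
The constructor above stops copying flags off the wrapped method and instead keeps a reference and delegates through properties. A standalone sketch of that delegation pattern, using hypothetical class names rather than the real vLLM types:

class InnerMethod:
    # Stand-in for the wrapped quantization method.
    supports_eplb = True
    allow_inplace = False


class DelegatingMethod:
    def __init__(self, inner: InnerMethod):
        # Keep a reference instead of copying attributes at construction
        # time, so later changes on the inner object remain visible.
        self.inner = inner

    @property
    def supports_eplb(self) -> bool:
        return self.inner.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        return self.inner.allow_inplace


wrapper = DelegatingMethod(InnerMethod())
assert wrapper.supports_eplb and not wrapper.allow_inplace
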
@@ -405,10 +408,11 @@ def apply(
                 assert isinstance(layer, FusedMoE)
             else:
                 raise NotImplementedError(
-                    f"EPLB is not supported for {self.old_method_name}"
+                    "EPLB is not supported for "
+                    f"{self.old_quant_method.__class__.__name__}."
                 )
 
-        select_result = FusedMoE.select_experts(
+        topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -431,8 +435,6 @@ def apply(
             zero_expert_type=zero_expert_type,
         )
 
-        topk_weights, topk_ids, zero_expert_result = select_result
-
         result = self.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -1421,7 +1423,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
             )
 
             if not isinstance(
-                quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+                self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
             ):
                 raise NotImplementedError(
                     "is_act_and_mul=False is supported only for unquantized "
@@ -1441,6 +1443,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
                # If you plan to add support for more quantization methods,
                # please refer to the implementation in `Fp8MoEMethod`.
                raise NotImplementedError(
+                    f"EPLB is not supported for {self.quant_method.__class__.__name__}. "
                    "EPLB is only supported for FP8 quantization for now."
                )
 
@@ -1466,12 +1469,12 @@ def _get_quant_method() -> FusedMoEMethodBase:
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None
 
-    # Note: init_prepare_finalize should only be called by
+    # Note: maybe_init_modular_kernel should only be called by
     # prepare_communication_buffer_for_model.
     # This is called after all weight loading and post-processing, so it
     # should be safe to swap out the quant_method.
-    def init_prepare_finalize(self) -> None:
-        mk = self.quant_method.init_prepare_finalize(self)
+    def maybe_init_modular_kernel(self) -> None:
+        mk = self.quant_method.maybe_init_modular_kernel(self)
         if mk is not None:
             self.quant_method = FusedMoEModularMethod(self.quant_method, mk)
 