@@ -314,6 +314,8 @@ def __init__(
314314 inplace : bool = True ,
315315 no_combine : bool = False ,
316316 routed_scaling_factor : Optional [float ] = None ,
317+ enable_flashinfer_moe : Optional [bool ] = False ,
318+ enable_ep_moe : Optional [bool ] = False ,
317319 ):
318320 super ().__init__ ()
319321
@@ -324,9 +326,34 @@ def __init__(
324326 self .tp_size = (
325327 tp_size if tp_size is not None else get_tensor_model_parallel_world_size ()
326328 )
329+ self .tp_rank = get_tensor_model_parallel_rank ()
330+ self .num_experts = num_experts
331+ self .expert_map = None
332+ self .enable_flashinfer_moe = enable_flashinfer_moe
333+ if enable_ep_moe :
334+ assert (
335+ self .enable_flashinfer_moe
336+ ), "FusedMoE only supports EP with --enable-flashinfer-moe"
337+ self .ep_size = self .tp_size
338+ self .ep_rank = self .tp_rank
339+ self .tp_size = 1
340+ self .tp_rank = 0
341+ # Create a tensor of size num_experts filled with -1
342+ self .expert_map = torch .full ((self .num_experts ,), - 1 , dtype = torch .int32 )
343+ # Create an expert map for the local experts
344+ assert num_experts % self .ep_size == 0
345+ self .local_num_experts = num_experts // self .ep_size
346+ self .expert_map [
347+ self .ep_rank
348+ * self .local_num_experts : (self .ep_rank + 1 )
349+ * self .local_num_experts
350+ ] = torch .arange (0 , self .local_num_experts , dtype = torch .int32 , device = "cpu" )
351+ else :
352+ self .ep_size = 1
353+ self .ep_rank = 0
354+ self .local_num_experts = num_experts
327355 self .routed_scaling_factor = routed_scaling_factor
328356 self .top_k = top_k
329- self .num_experts = num_experts
330357 assert intermediate_size % self .tp_size == 0
331358 self .intermediate_size_per_partition = intermediate_size // self .tp_size
332359 self .reduce_results = reduce_results
@@ -344,19 +371,20 @@ def __init__(
344371 self .use_presharded_weights = use_presharded_weights
345372 self .inplace = inplace
346373 self .no_combine = no_combine
347- self .local_num_experts = num_experts
348374
349375 if quant_config is None :
350376 self .quant_method : Optional [QuantizeMethodBase ] = (
351377 UnquantizedFusedMoEMethod ()
352378 )
353379 else :
354380 self .quant_method = quant_config .get_quant_method (self , prefix )
381+ if self .quant_method .__class__ .__name__ == "ModelOptNvFp4FusedMoEMethod" :
382+ self .quant_method .enable_flashinfer_moe = self .enable_flashinfer_moe
355383 assert self .quant_method is not None
356384
357385 self .quant_method .create_weights (
358386 layer = self ,
359- num_experts = num_experts ,
387+ num_experts = self . local_num_experts ,
360388 hidden_size = hidden_size ,
361389 # FIXME: figure out which intermediate_size to use
362390 intermediate_size = self .intermediate_size_per_partition ,
@@ -450,12 +478,15 @@ def _load_w13(
450478
451479 # Narrow parameter and load.
452480 # w1, gate_proj: Load into first logical weight of w13.
453- if shard_id == "w1" :
454- expert_data = expert_data .narrow (shard_dim , 0 , shard_size )
455481 # w3, up_proj: Load into second logical weight of w13.
482+ # trtllm cutlass kernel assumes differently
483+ assert shard_id in ("w1" , "w3" )
484+ switch_w13 = getattr (self .quant_method , "load_up_proj_weight_first" , False )
485+ if (switch_w13 and shard_id == "w1" ) or (not switch_w13 and shard_id == "w3" ):
486+ start = shard_size
456487 else :
457- assert shard_id == "w3"
458- expert_data = expert_data .narrow (shard_dim , shard_size , shard_size )
488+ start = 0
489+ expert_data = expert_data .narrow (shard_dim , start , shard_size )
459490 expert_data .copy_ (loaded_weight )
460491
461492 def _load_w2 (
@@ -509,6 +540,11 @@ def _load_g_idx(
509540 assert shard_id in ("w1" , "w3" )
510541 expert_data .copy_ (loaded_weight )
511542
543+ def _map_global_expert_id_to_local_expert_id (self , expert_id : int ) -> int :
544+ if self .expert_map is None :
545+ return expert_id
546+ return self .expert_map [expert_id ].item ()
547+
512548 def weight_loader (
513549 self ,
514550 param : torch .nn .Parameter ,
@@ -517,6 +553,13 @@ def weight_loader(
517553 shard_id : str ,
518554 expert_id : int ,
519555 ) -> None :
556+ expert_id = self ._map_global_expert_id_to_local_expert_id (expert_id )
557+ if expert_id == - 1 :
558+ return
559+
560+ # TP rank is set to 0 if EP is enabled
561+ tp_rank = 0 if self .ep_size > 1 else get_tensor_model_parallel_rank ()
562+
520563 # compressed-tensors checkpoints with packed weights are stored flipped
521564 # TODO (mgoin): check self.quant_method.quant_config.quant_format
522565 # against known CompressionFormat enum values that have this quality
@@ -541,15 +584,14 @@ def weight_loader(
541584 SHARD_ID_TO_SHARDED_DIM = {"w1" : 0 , "w2" : 1 , "w3" : 0 }
542585
543586 expert_data = param .data [expert_id ]
544- tp_rank = get_tensor_model_parallel_rank ()
545587
546588 # is_transposed: if the dim to shard the weight
547589 # should be flipped. Required by GPTQ, compressed-tensors
548590 # should be whatever dimension intermediate_size is
549591 is_transposed = getattr (param , "is_transposed" , False )
550592 shard_dim = SHARD_ID_TO_SHARDED_DIM [shard_id ]
551593 if is_transposed :
552- shard_dim = ~ shard_dim
594+ shard_dim = int ( not shard_dim )
553595
554596 # Case input scale: input_scale loading is only supported for fp8
555597 if "input_scale" in weight_name :
@@ -690,9 +732,19 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
690732 activation = self .activation ,
691733 apply_router_weight_on_input = self .apply_router_weight_on_input ,
692734 routed_scaling_factor = self .routed_scaling_factor ,
735+ ** (
736+ dict (
737+ tp_rank = self .tp_rank ,
738+ tp_size = self .tp_size ,
739+ ep_rank = self .ep_rank ,
740+ ep_size = self .ep_size ,
741+ )
742+ if self .quant_method .__class__ .__name__ == "ModelOptNvFp4FusedMoEMethod"
743+ else {}
744+ ),
693745 )
694746
695- if self .reduce_results and self .tp_size > 1 :
747+ if self .reduce_results and ( self .tp_size > 1 or self . ep_size > 1 ) :
696748 final_hidden_states = tensor_model_parallel_all_reduce (final_hidden_states )
697749
698750 return final_hidden_states
0 commit comments