@@ -244,6 +244,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -400,6 +401,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -427,6 +429,7 @@ def apply(
             expert_map=expert_map,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
@@ -450,6 +453,7 @@ def forward_cuda(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -469,6 +473,7 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             enable_eplb=enable_eplb,
@@ -534,6 +539,7 @@ def forward_cpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -560,6 +566,7 @@ def forward_cpu(
             expert_map,
             custom_routing_function,
             scoring_func,
+            routed_scaling_factor,
             e_score_correction_bias,
             apply_router_weight_on_input,
             activation,
@@ -579,6 +586,7 @@ def forward_xpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -617,6 +625,7 @@ def forward_tpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -637,6 +646,9 @@ def forward_tpu(
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
         assert activation == "silu", f"{activation} is not supported for TPU."
+        assert routed_scaling_factor == 1.0, \
+            f"routed_scaling_factor {routed_scaling_factor} is not supported " \
+            f"for TPU."
         if enable_eplb is not False or expert_load_view is not None or \
                 logical_to_physical_map is not None or \
                 logical_replica_count is not None:
@@ -766,6 +778,7 @@ def __init__(
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -848,6 +861,7 @@ def __init__(
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
         self.e_score_correction_bias = e_score_correction_bias
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.activation = activation
@@ -1416,6 +1430,7 @@ def select_experts(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         indices_type: Optional[torch.dtype] = None,
         enable_eplb: bool = False,
@@ -1460,6 +1475,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 scoring_func=scoring_func,
+                routed_scaling_factor=routed_scaling_factor,
                 e_score_correction_bias=e_score_correction_bias)
             if indices_type is not None:
                 topk_ids = topk_ids.to(dtype=indices_type)
@@ -1627,6 +1643,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 num_expert_group=self.num_expert_group,
                 custom_routing_function=self.custom_routing_function,
                 scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
                 e_score_correction_bias=self.e_score_correction_bias,
                 activation=self.activation,
                 enable_eplb=self.enable_eplb,
@@ -1695,6 +1712,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,