@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
+                          activation: str = "silu",
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a16: bool = False,
                           use_int4_w4a16: bool = False,
@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
-                       global_num_experts, expert_map, w1_scale, w2_scale,
-                       w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+                       activation, use_fp8_w8a8, use_int8_w8a16,
+                       use_int4_w4a16, global_num_experts, expert_map,
+                       w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+                       block_shape)
 
 
 def inplace_fused_experts_fake(
@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1093,6 +1096,7 @@ def outplace_fused_experts(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1106,7 +1110,7 @@ def outplace_fused_experts(
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, use_fp8_w8a8, use_int8_w8a16,
+                              False, activation, use_fp8_w8a8, use_int8_w8a16,
                               use_int4_w4a16, global_num_experts, expert_map,
                               w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
                               a2_scale, block_shape)
@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
+                  activation: str = "silu",
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
                   use_int4_w4a16: bool = False,
@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,
 
     if inplace:
         torch.ops.vllm.inplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
-            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+            hidden_states, w1, w2, topk_weights, topk_ids, activation,
+            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
+            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+            block_shape)
         return hidden_states
     else:
         return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
-            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+            hidden_states, w1, w2, topk_weights, topk_ids, activation,
+            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
+            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+            block_shape)
 
 
 def fused_experts_impl(hidden_states: torch.Tensor,
@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        inplace: bool = False,
+                       activation: str = "silu",
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
                        use_int4_w4a16: bool = False,
@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 use_int4_w4a16=use_int4_w4a16,
                                 block_shape=block_shape)
 
-        torch.ops._C.silu_and_mul(intermediate_cache2,
-                                  intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
 
         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
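For context, the silu_and_mul / gelu_and_mul custom ops dispatched above follow the gated-MLP convention: the (num_tokens, 2 * N) output of the first grouped GEMM is split into a gate half and an up half, the activation is applied to the gate half, and the two are multiplied elementwise. A minimal PyTorch sketch of that assumed semantics (reference only, not the CUDA kernels; silu_and_mul_ref and gelu_and_mul_ref are hypothetical helpers):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * N); activate the gate half, scale by the up half.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate) * up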
@@ -1339,6 +1354,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1370,6 +1386,8 @@ def fused_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
+    - activation (str): The activation function to apply after the first
+        MoE layer.
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -1420,6 +1438,7 @@ def fused_moe(
                          topk_weights,
                          topk_ids,
                          inplace=inplace,
+                         activation=activation,
                          use_fp8_w8a8=use_fp8_w8a8,
                          use_int8_w8a16=use_int8_w8a16,
                          use_int4_w4a16=use_int4_w4a16,
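A hedged usage sketch of the new parameter: it assumes fused_moe is importable from vllm.model_executor.layers.fused_moe.fused_moe (the module patched above), that the weight layout follows the documented convention (w1: [E, 2 * N, K] with fused gate/up projections, w2: [E, K, N]), and that a CUDA device is available; all shapes and values are illustrative only.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

num_tokens, hidden_size, intermediate_size = 4, 128, 256
num_experts, topk = 8, 2

hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
# w1 fuses the gate and up projections, hence the 2 * intermediate_size dim.
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                 dtype=torch.float16, device="cuda")
w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                 dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float32, device="cuda")

# activation defaults to "silu"; "gelu" selects the new gelu_and_mul path.
out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=topk, renormalize=True, activation="gelu")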