Skip to content

Commit c0847e7

Browse files
xyang16mgoinDarkLight1337
authored andcommitted
Add routed_scaling_factor to MoE grouped topk (vllm-project#23123)
Signed-off-by: Xin Yang <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent df48bf0 commit c0847e7

File tree

19 files changed

+77
-4
lines changed

19 files changed

+77
-4
lines changed

vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def grouped_topk(
2121
num_expert_group: int = 0,
2222
topk_group: int = 0,
2323
scoring_func: str = "softmax",
24+
routed_scaling_factor: float = 1.0,
2425
e_score_correction_bias: Optional[torch.Tensor] = None
2526
) -> tuple[torch.Tensor, torch.Tensor]:
2627
assert hidden_states.shape[0] == gating_output.shape[0], (
@@ -65,6 +66,8 @@ def grouped_topk(
6566
if renormalize:
6667
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
6768

69+
if routed_scaling_factor != 1.0:
70+
topk_weights = topk_weights * routed_scaling_factor
6871
return topk_weights, topk_ids.to(torch.int32)
6972

7073

@@ -78,6 +81,7 @@ def select_experts(
7881
num_expert_group: Optional[int] = None,
7982
custom_routing_function: Optional[Callable] = None,
8083
scoring_func: str = "softmax",
84+
routed_scaling_factor: float = 1.0,
8185
e_score_correction_bias: Optional[torch.Tensor] = None,
8286
) -> tuple[torch.Tensor, torch.Tensor]:
8387
if use_grouped_topk:
@@ -90,6 +94,7 @@ def select_experts(
9094
num_expert_group=num_expert_group,
9195
topk_group=topk_group,
9296
scoring_func=scoring_func,
97+
routed_scaling_factor=routed_scaling_factor,
9398
e_score_correction_bias=e_score_correction_bias)
9499
elif custom_routing_function is None:
95100
assert scoring_func == "softmax"
@@ -131,12 +136,15 @@ def __call__(
131136
expert_map: Optional[torch.Tensor] = None,
132137
custom_routing_function: Optional[Callable] = None,
133138
scoring_func: str = "softmax",
139+
routed_scaling_factor: float = 1.0,
134140
e_score_correction_bias: Optional[torch.Tensor] = None,
135141
apply_router_weight_on_input: bool = False,
136142
activation: str = "silu",
137143
) -> torch.Tensor:
138144
assert activation == "silu", f"{activation} is not supported."
139145
assert not apply_router_weight_on_input
146+
assert routed_scaling_factor == 1.0, \
147+
f"routed_scaling_factor {routed_scaling_factor} is not supported."
140148
return layer.ipex_fusion(
141149
x,
142150
use_grouped_topk,
@@ -170,6 +178,7 @@ def __call__(
170178
expert_map: Optional[torch.Tensor] = None,
171179
custom_routing_function: Optional[Callable] = None,
172180
scoring_func: str = "softmax",
181+
routed_scaling_factor: float = 1.0,
173182
e_score_correction_bias: Optional[torch.Tensor] = None,
174183
apply_router_weight_on_input: bool = False,
175184
activation: str = "silu",
@@ -186,6 +195,7 @@ def __call__(
186195
num_expert_group=num_expert_group,
187196
custom_routing_function=custom_routing_function,
188197
scoring_func=scoring_func,
198+
routed_scaling_factor=routed_scaling_factor,
189199
e_score_correction_bias=e_score_correction_bias,
190200
)
191201

@@ -227,6 +237,7 @@ def __call__(
227237
expert_map: Optional[torch.Tensor] = None,
228238
custom_routing_function: Optional[Callable] = None,
229239
scoring_func: str = "softmax",
240+
routed_scaling_factor: float = 1.0,
230241
e_score_correction_bias: Optional[torch.Tensor] = None,
231242
apply_router_weight_on_input: bool = False,
232243
activation: str = "silu",
@@ -243,6 +254,7 @@ def __call__(
243254
num_expert_group=num_expert_group,
244255
custom_routing_function=custom_routing_function,
245256
scoring_func=scoring_func,
257+
routed_scaling_factor=routed_scaling_factor,
246258
e_score_correction_bias=e_score_correction_bias,
247259
)
248260

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,8 @@ def grouped_topk(
10111011
if renormalize:
10121012
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
10131013

1014-
topk_weights = topk_weights * routed_scaling_factor
1014+
if routed_scaling_factor != 1.0:
1015+
topk_weights = topk_weights * routed_scaling_factor
10151016
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
10161017

10171018

@@ -1790,8 +1791,8 @@ def fused_moe(
17901791
Defaults to False.
17911792
- global_num_experts (int): The total number of experts in the global
17921793
expert space.
1793-
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
1794-
from the global expert space to the local expert space of the expert
1794+
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
1795+
from the global expert space to the local expert space of the expert
17951796
parallel shard.
17961797
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
17971798
w1.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def apply(
244244
expert_map: Optional[torch.Tensor] = None,
245245
custom_routing_function: Optional[Callable] = None,
246246
scoring_func: str = "softmax",
247+
routed_scaling_factor: float = 1.0,
247248
e_score_correction_bias: Optional[torch.Tensor] = None,
248249
apply_router_weight_on_input: bool = False,
249250
activation: str = "silu",
@@ -400,6 +401,7 @@ def apply(
400401
expert_map: Optional[torch.Tensor] = None,
401402
custom_routing_function: Optional[Callable] = None,
402403
scoring_func: str = "softmax",
404+
routed_scaling_factor: float = 1.0,
403405
e_score_correction_bias: Optional[torch.Tensor] = None,
404406
apply_router_weight_on_input: bool = False,
405407
activation: str = "silu",
@@ -427,6 +429,7 @@ def apply(
427429
expert_map=expert_map,
428430
custom_routing_function=custom_routing_function,
429431
scoring_func=scoring_func,
432+
routed_scaling_factor=routed_scaling_factor,
430433
e_score_correction_bias=e_score_correction_bias,
431434
activation=activation,
432435
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -450,6 +453,7 @@ def forward_cuda(
450453
expert_map: Optional[torch.Tensor] = None,
451454
custom_routing_function: Optional[Callable] = None,
452455
scoring_func: str = "softmax",
456+
routed_scaling_factor: float = 1.0,
453457
e_score_correction_bias: Optional[torch.Tensor] = None,
454458
apply_router_weight_on_input: bool = False,
455459
activation: str = "silu",
@@ -469,6 +473,7 @@ def forward_cuda(
469473
num_expert_group=num_expert_group,
470474
custom_routing_function=custom_routing_function,
471475
scoring_func=scoring_func,
476+
routed_scaling_factor=routed_scaling_factor,
472477
e_score_correction_bias=e_score_correction_bias,
473478
indices_type=self.topk_indices_dtype,
474479
enable_eplb=enable_eplb,
@@ -534,6 +539,7 @@ def forward_cpu(
534539
expert_map: Optional[torch.Tensor] = None,
535540
custom_routing_function: Optional[Callable] = None,
536541
scoring_func: str = "softmax",
542+
routed_scaling_factor: float = 1.0,
537543
e_score_correction_bias: Optional[torch.Tensor] = None,
538544
apply_router_weight_on_input: bool = False,
539545
activation: str = "silu",
@@ -560,6 +566,7 @@ def forward_cpu(
560566
expert_map,
561567
custom_routing_function,
562568
scoring_func,
569+
routed_scaling_factor,
563570
e_score_correction_bias,
564571
apply_router_weight_on_input,
565572
activation,
@@ -579,6 +586,7 @@ def forward_xpu(
579586
expert_map: Optional[torch.Tensor] = None,
580587
custom_routing_function: Optional[Callable] = None,
581588
scoring_func: str = "softmax",
589+
routed_scaling_factor: float = 1.0,
582590
e_score_correction_bias: Optional[torch.Tensor] = None,
583591
apply_router_weight_on_input: bool = False,
584592
activation: str = "silu",
@@ -617,6 +625,7 @@ def forward_tpu(
617625
expert_map: Optional[torch.Tensor] = None,
618626
custom_routing_function: Optional[Callable] = None,
619627
scoring_func: str = "softmax",
628+
routed_scaling_factor: float = 1.0,
620629
e_score_correction_bias: Optional[torch.Tensor] = None,
621630
apply_router_weight_on_input: bool = False,
622631
activation: str = "silu",
@@ -637,6 +646,9 @@ def forward_tpu(
637646
raise NotImplementedError(
638647
"Expert score correction bias is not supported for TPU.")
639648
assert activation == "silu", f"{activation} is not supported for TPU."
649+
assert routed_scaling_factor == 1.0, \
650+
f"routed_scaling_factor {routed_scaling_factor} is not supported " \
651+
f"for TPU."
640652
if enable_eplb is not False or expert_load_view is not None or \
641653
logical_to_physical_map is not None or \
642654
logical_replica_count is not None:
@@ -766,6 +778,7 @@ def __init__(
766778
prefix: str = "",
767779
custom_routing_function: Optional[Callable] = None,
768780
scoring_func: str = "softmax",
781+
routed_scaling_factor: float = 1.0,
769782
e_score_correction_bias: Optional[torch.Tensor] = None,
770783
apply_router_weight_on_input: bool = False,
771784
activation: str = "silu",
@@ -848,6 +861,7 @@ def __init__(
848861
self.topk_group = topk_group
849862
self.custom_routing_function = custom_routing_function
850863
self.scoring_func = scoring_func
864+
self.routed_scaling_factor = routed_scaling_factor
851865
self.e_score_correction_bias = e_score_correction_bias
852866
self.apply_router_weight_on_input = apply_router_weight_on_input
853867
self.activation = activation
@@ -1416,6 +1430,7 @@ def select_experts(
14161430
num_expert_group: Optional[int] = None,
14171431
custom_routing_function: Optional[Callable] = None,
14181432
scoring_func: str = "softmax",
1433+
routed_scaling_factor: float = 1.0,
14191434
e_score_correction_bias: Optional[torch.Tensor] = None,
14201435
indices_type: Optional[torch.dtype] = None,
14211436
enable_eplb: bool = False,
@@ -1460,6 +1475,7 @@ def select_experts(
14601475
num_expert_group=num_expert_group,
14611476
topk_group=topk_group,
14621477
scoring_func=scoring_func,
1478+
routed_scaling_factor=routed_scaling_factor,
14631479
e_score_correction_bias=e_score_correction_bias)
14641480
if indices_type is not None:
14651481
topk_ids = topk_ids.to(dtype=indices_type)
@@ -1627,6 +1643,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
16271643
num_expert_group=self.num_expert_group,
16281644
custom_routing_function=self.custom_routing_function,
16291645
scoring_func=self.scoring_func,
1646+
routed_scaling_factor=self.routed_scaling_factor,
16301647
e_score_correction_bias=self.e_score_correction_bias,
16311648
activation=self.activation,
16321649
enable_eplb=self.enable_eplb,
@@ -1695,6 +1712,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
16951712
num_expert_group=self.num_expert_group,
16961713
custom_routing_function=self.custom_routing_function,
16971714
scoring_func=self.scoring_func,
1715+
routed_scaling_factor=self.routed_scaling_factor,
16981716
e_score_correction_bias=self.e_score_correction_bias,
16991717
activation=self.activation,
17001718
apply_router_weight_on_input=self.apply_router_weight_on_input,

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def rocm_aiter_grouped_topk(
267267
num_expert_group: int = 0,
268268
topk_group: int = 0,
269269
scoring_func: str = "softmax",
270+
routed_scaling_factor: float = 1.0,
270271
e_score_correction_bias: Optional[torch.Tensor] = None
271272
) -> tuple[torch.Tensor, torch.Tensor]:
272273
token = hidden_states.shape[0]
@@ -298,6 +299,8 @@ def rocm_aiter_grouped_topk(
298299
scoring_func,
299300
)
300301

302+
if routed_scaling_factor != 1.0:
303+
topk_weights = topk_weights * routed_scaling_factor
301304
return topk_weights, topk_ids
302305

303306

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def apply(
497497
expert_map: Optional[torch.Tensor] = None,
498498
custom_routing_function: Optional[Callable] = None,
499499
scoring_func: str = "softmax",
500+
routed_scaling_factor: float = 1.0,
500501
e_score_correction_bias: Optional[torch.Tensor] = None,
501502
apply_router_weight_on_input: bool = False,
502503
activation: str = "silu",
@@ -523,6 +524,7 @@ def apply(
523524
num_expert_group=num_expert_group,
524525
custom_routing_function=custom_routing_function,
525526
scoring_func=scoring_func,
527+
routed_scaling_factor=routed_scaling_factor,
526528
e_score_correction_bias=e_score_correction_bias,
527529
indices_type=self.topk_indices_dtype)
528530

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def apply(
466466
expert_map: Optional[torch.Tensor] = None,
467467
custom_routing_function: Optional[Callable] = None,
468468
scoring_func: str = "softmax",
469+
routed_scaling_factor: float = 1.0,
469470
e_score_correction_bias: Optional[torch.Tensor] = None,
470471
apply_router_weight_on_input: bool = False,
471472
activation: str = "silu",
@@ -490,6 +491,7 @@ def apply(
490491
num_expert_group=num_expert_group,
491492
custom_routing_function=custom_routing_function,
492493
scoring_func=scoring_func,
494+
routed_scaling_factor=routed_scaling_factor,
493495
e_score_correction_bias=e_score_correction_bias,
494496
indices_type=self.topk_indices_dtype)
495497
if self.quant_config.load_in_8bit:

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def apply(
350350
expert_map: Optional[torch.Tensor] = None,
351351
custom_routing_function: Optional[Callable] = None,
352352
scoring_func: str = "softmax",
353+
routed_scaling_factor: float = 1.0,
353354
e_score_correction_bias: Optional[torch.Tensor] = None,
354355
apply_router_weight_on_input: bool = False,
355356
activation: str = "silu",
@@ -375,6 +376,7 @@ def apply(
375376
num_expert_group=num_expert_group,
376377
custom_routing_function=custom_routing_function,
377378
scoring_func=scoring_func,
379+
routed_scaling_factor=routed_scaling_factor,
378380
e_score_correction_bias=e_score_correction_bias,
379381
indices_type=self.topk_indices_dtype,
380382
)
@@ -809,6 +811,7 @@ def apply(
809811
expert_map: Optional[torch.Tensor] = None,
810812
custom_routing_function: Optional[Callable] = None,
811813
scoring_func: str = "softmax",
814+
routed_scaling_factor: float = 1.0,
812815
e_score_correction_bias: Optional[torch.Tensor] = None,
813816
apply_router_weight_on_input: bool = False,
814817
activation: str = "silu",
@@ -832,6 +835,7 @@ def apply(
832835
num_expert_group=num_expert_group,
833836
custom_routing_function=custom_routing_function,
834837
scoring_func=scoring_func,
838+
routed_scaling_factor=routed_scaling_factor,
835839
e_score_correction_bias=e_score_correction_bias,
836840
indices_type=self.topk_indices_dtype,
837841
)
@@ -1057,6 +1061,7 @@ def apply(
10571061
expert_map: Optional[torch.Tensor] = None,
10581062
custom_routing_function: Optional[Callable] = None,
10591063
scoring_func: str = "softmax",
1064+
routed_scaling_factor: float = 1.0,
10601065
e_score_correction_bias: Optional[torch.Tensor] = None,
10611066
apply_router_weight_on_input: bool = False,
10621067
activation: str = "silu",
@@ -1084,6 +1089,7 @@ def apply(
10841089
num_expert_group=num_expert_group,
10851090
custom_routing_function=custom_routing_function,
10861091
scoring_func=scoring_func,
1092+
routed_scaling_factor=routed_scaling_factor,
10871093
e_score_correction_bias=e_score_correction_bias,
10881094
indices_type=self.topk_indices_dtype)
10891095

@@ -1361,6 +1367,7 @@ def apply(
13611367
expert_map: Optional[torch.Tensor] = None,
13621368
custom_routing_function: Optional[Callable] = None,
13631369
scoring_func: str = "softmax",
1370+
routed_scaling_factor: float = 1.0,
13641371
e_score_correction_bias: Optional[torch.Tensor] = None,
13651372
apply_router_weight_on_input: bool = False,
13661373
activation: str = "silu",
@@ -1389,6 +1396,7 @@ def apply(
13891396
num_expert_group=num_expert_group,
13901397
custom_routing_function=custom_routing_function,
13911398
scoring_func=scoring_func,
1399+
routed_scaling_factor=routed_scaling_factor,
13921400
e_score_correction_bias=e_score_correction_bias,
13931401
indices_type=self.topk_indices_dtype)
13941402

@@ -1592,6 +1600,7 @@ def apply(
15921600
expert_map: Optional[torch.Tensor] = None,
15931601
custom_routing_function: Optional[Callable] = None,
15941602
scoring_func: str = "softmax",
1603+
routed_scaling_factor: float = 1.0,
15951604
e_score_correction_bias: Optional[torch.Tensor] = None,
15961605
apply_router_weight_on_input: bool = False,
15971606
activation: str = "silu",
@@ -1618,6 +1627,7 @@ def apply(
16181627
num_expert_group=num_expert_group,
16191628
custom_routing_function=custom_routing_function,
16201629
scoring_func=scoring_func,
1630+
routed_scaling_factor=routed_scaling_factor,
16211631
e_score_correction_bias=e_score_correction_bias,
16221632
indices_type=self.topk_indices_dtype)
16231633

vllm/model_executor/layers/quantization/experts_int8.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def apply(
120120
expert_map: Optional[torch.Tensor] = None,
121121
custom_routing_function: Optional[Callable] = None,
122122
scoring_func: str = "softmax",
123+
routed_scaling_factor: float = 1.0,
123124
e_score_correction_bias: Optional[torch.Tensor] = None,
124125
apply_router_weight_on_input: bool = False,
125126
activation: str = "silu",
@@ -146,6 +147,7 @@ def apply(
146147
num_expert_group=num_expert_group,
147148
custom_routing_function=custom_routing_function,
148149
scoring_func=scoring_func,
150+
routed_scaling_factor=routed_scaling_factor,
149151
e_score_correction_bias=e_score_correction_bias,
150152
indices_type=self.topk_indices_dtype)
151153

0 commit comments

Comments
 (0)