Commit 1a425d7

PerryZhang01 authored and zgplvyou committed
[EPLB][ROCm]: support EPLB for ROCm backend (vllm-project#27731)

Signed-off-by: Perry Zhang <[email protected]>
Co-authored-by: Perry Zhang <[email protected]>
1 parent 07d7828 commit 1a425d7

File tree

4 files changed: +23 −6 lines

vllm/config/parallel.py

Lines changed: 2 additions & 2 deletions

@@ -278,10 +278,10 @@ def _validate_parallel_config(self) -> Self:
             )

         if self.enable_eplb:
-            if not current_platform.is_cuda():
+            if not current_platform.is_cuda_alike():
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
-                    "CUDA devices now."
+                    "CUDA devices or ROCm devices now."
                 )
             if not self.enable_expert_parallel:
                 raise ValueError("enable_expert_parallel must be True to use EPLB.")

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 1 deletion

@@ -1218,7 +1218,11 @@ def load_weights(

     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
-        assert all(weight.is_contiguous() for _, weight in weights)
+        assert all(
+            weight.is_contiguous()
+            for name, weight in weights
+            if not name.startswith("_shared_experts.")
+        )

         # Filter out the non-expert weights.
         # `e_score_correction_bias` is a bias for each logical expert,
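
Note: the loosened assertion skips any parameter whose name starts with `_shared_experts.`, since fused shared-expert weights are not rearranged by EPLB and so need not be contiguous. A self-contained sketch of the filter, where the parameter names are illustrative rather than taken from a real model:

import torch

# Illustrative parameter dict; the names mimic the convention of prefixing
# fused shared-expert weights with "_shared_experts.".
params = {
    "w13_weight": torch.randn(8, 16, 32),
    "w2_weight": torch.randn(8, 32, 16),
    # A non-contiguous view, as a fused shared-expert weight might be:
    "_shared_experts.gate_proj.weight": torch.randn(32, 16).t(),
}

# Only expert weights must be contiguous for EPLB's weight rearrangement;
# shared-expert tensors are excluded from the check, so this passes.
assert all(
    weight.is_contiguous()
    for name, weight in params.items()
    if not name.startswith("_shared_experts.")
)
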
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 13 additions & 3 deletions

@@ -1019,9 +1019,10 @@ def apply(
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
-            )
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)

         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
@@ -1037,6 +1038,11 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )

         per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1145,6 +1151,10 @@ def apply(
             quant_config=self.moe_quant_config,
         )

+    @property
+    def supports_eplb(self) -> bool:
+        return True
+

 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
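
Note: replacing the `NotImplementedError` with argument asserts, together with the new `supports_eplb` property, lets callers detect EPLB support up front instead of failing inside `apply`. A hedged sketch of how such a capability flag could be consumed, with `MoEMethodBase`, `W8A8Fp8Method`, and `check_eplb` as hypothetical names for illustration:

class MoEMethodBase:
    """Base class defaulting to no EPLB support."""

    @property
    def supports_eplb(self) -> bool:
        return False


class W8A8Fp8Method(MoEMethodBase):
    """Stand-in for CompressedTensorsW8A8Fp8MoEMethod after this commit."""

    @property
    def supports_eplb(self) -> bool:
        return True


def check_eplb(method: MoEMethodBase, enable_eplb: bool) -> None:
    # Fail fast at configuration time rather than mid-forward-pass.
    if enable_eplb and not method.supports_eplb:
        raise NotImplementedError(
            f"EPLB not supported for {type(method).__name__} yet."
        )


check_eplb(W8A8Fp8Method(), enable_eplb=True)  # passes silently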
