1 change: 1 addition & 0 deletions tests/v1/tpu/test_basic.py
@@ -18,6 +18,7 @@
 
 MODELS = [
     "Qwen/Qwen2.5-1.5B-Instruct",
+    "Qwen/Qwen1.5-MoE-A2.7B",
     # TODO: Enable this models with v6e
     # "Qwen/Qwen2-7B-Instruct",
     # "meta-llama/Llama-3.1-8B",
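The new entry runs a mixture-of-experts checkpoint through the same v1 TPU smoke test as the existing dense models. Below is a minimal sketch of exercising that model with vLLM's offline API; the engine arguments (max_model_len, max_num_seqs, enforce_eager) and the prompt are illustrative assumptions, not the values hard-coded in test_basic.py.

```python
# Hedged sketch: load the newly listed MoE checkpoint the same way the TPU
# smoke test loads the dense models. Engine arguments below are assumptions
# chosen to keep the run small, not the test's actual configuration.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",
    max_model_len=512,    # assumption: short context keeps the smoke test fast
    max_num_seqs=16,      # assumption
    enforce_eager=True,   # assumption: skip compilation for a quick check
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```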
19 changes: 18 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -481,8 +481,16 @@ def forward_cpu(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
-        **kwargs,
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for CPU.")
         return layer.cpu_fused_moe(
             layer,
             x,
@@ -518,6 +526,10 @@ def forward_tpu(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -531,6 +543,11 @@ def forward_tpu(
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
         assert activation == "silu", f"{activation} is not supported for TPU."
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for TPU.")
         return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
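To make the intent of the new arguments explicit: both backends now accept the EPLB-related parameters so shared call sites can use a uniform signature, but any non-default value is rejected up front. The sketch below isolates that guard pattern; `_reject_eplb_args` is a hypothetical helper written for illustration only and is not part of vLLM.

```python
# Hedged sketch of the guard added to forward_cpu/forward_tpu: the EPLB
# arguments are accepted for signature compatibility, and any non-default
# value raises immediately. `_reject_eplb_args` is a hypothetical name.
from typing import Optional

import torch


def _reject_eplb_args(
        backend: str,
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> None:
    if enable_eplb or expert_load_view is not None or \
            logical_to_physical_map is not None or \
            logical_replica_count is not None:
        raise NotImplementedError(
            f"Expert load balancing is not supported for {backend}.")


_reject_eplb_args("TPU")  # defaults pass silently
try:
    _reject_eplb_args("TPU", enable_eplb=True)
except NotImplementedError as exc:
    print(exc)  # Expert load balancing is not supported for TPU.
```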