127 changes: 126 additions & 1 deletion tests/distributed/test_eplb_execute.py
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import random

import pytest
import torch
import torch.distributed

from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.eplb.rebalance_execute import (
    move_from_buffer,
    rearrange_expert_weights_inplace,
    transfer_layer,
)
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_tp_group,
@@ -231,6 +236,100 @@ def verify_redundant_experts_have_same_weights(
)


def _test_async_transfer_layer_without_mtp_worker(
    env,
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
) -> None:
    set_env_vars_and_device(env)
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    tp_group = get_tp_group()
    ep_group = tp_group.device_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    total_physical_experts = world_size * num_local_experts
    hidden_sizes = [16, 32]

    redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    old_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        redundancy_config,
    )

    new_redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    new_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        new_redundancy_config,
    )

    expert_weights = create_expert_weights(
        num_layers,
        num_local_experts,
        hidden_sizes,
        ep_rank,
        device,
        old_indices,
    )

    expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
    cuda_stream = torch.cuda.Stream(device=device)

    for layer_idx in range(num_layers):
        is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
            transfer_layer(
                old_global_expert_indices=old_indices,
                new_global_expert_indices=new_indices,
                expert_weights=expert_weights,
                expert_weights_buffer=expert_buffer,
                ep_group=ep_group,
                layer=layer_idx,
                cuda_stream=cuda_stream,
            )
        )

        cuda_stream.synchronize()
        move_from_buffer(
            expert_weights=expert_weights[layer_idx],
            expert_weights_buffer=expert_buffer,
            is_unchanged=is_unchanged,
            is_received_locally=is_received_locally,
            experts_recv_loc=experts_recv_loc,
            new_indices=new_indices[layer_idx].tolist(),
            ep_group=ep_group,
        )

    verify_expert_weights_after_shuffle(
        expert_weights,
        new_indices,
        hidden_sizes,
        ep_rank,
        num_local_experts,
    )
    verify_redundant_experts_have_same_weights(
        expert_weights,
        new_indices,
        hidden_sizes,
        world_size,
        num_local_experts,
    )


def _test_rearrange_expert_weights_with_redundancy(
    env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:
@@ -399,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
)


@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        (2, 2, 2, 3),
    ],
)
def test_async_transfer_layer_without_mtp(
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
):
    """Exercise async EPLB transfer path without MTP/spec decode."""

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    distributed_run(
        _test_async_transfer_layer_without_mtp_worker,
        world_size,
        num_layers,
        num_local_experts,
        num_logical_experts,
    )


@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
"""
39 changes: 38 additions & 1 deletion tests/distributed/test_eplb_spec_decode.py
@@ -10,10 +10,11 @@

def get_model_args(
    model_name: str,
    spec_model_name: str,
    spec_model_name: str | None,
    spec_method: str,
    tp_size: int,
    model_max_len: int,
    use_async: bool = False,
) -> dict:
    speculative_config = {
        "method": spec_method,
@@ -37,6 +38,8 @@ def get_model_args(
"enable_eplb": True,
"max_model_len": model_max_len,
}
if use_async:
model_args["eplb_config"] = {"use_async": True}
return model_args


@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
        measured_value - RTOL < expected_gsm8k_value
        and measured_value + RTOL > expected_gsm8k_value
    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"


@large_gpu_mark(min_gb=80)
def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
    """
    Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
    """

    TASK = "gsm8k"
    FILTER = "exact_match,strict-match"
    RTOL = 0.03
    expected_gsm8k_value = 0.86

    model_args = get_model_args(
        model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
        spec_model_name=None,
        spec_method="mtp",
        tp_size=4,
        model_max_len=4096,
        use_async=True,
    )

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=TASK,
        batch_size=64,
        num_fewshot=8,
    )
    measured_value = results["results"][TASK][FILTER]
    assert (
        measured_value - RTOL < expected_gsm8k_value
        and measured_value + RTOL > expected_gsm8k_value
    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
4 changes: 4 additions & 0 deletions vllm/config/parallel.py
@@ -60,6 +60,10 @@ class EPLBConfig:
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB.
    """


@config
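Editor's note (illustrative, not part of this diff): the new flag is consumed as a plain dict entry, mirroring how the spec-decode test above builds its model_args. A minimal sketch of enabling it from user code follows; it assumes vllm.LLM forwards enable_eplb and eplb_config to the engine config the same way the lm_eval model_args do, and the model and parallel settings shown are placeholders.

# Hypothetical usage sketch, not taken from this PR:
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",  # any MoE model with EPLB support
    tensor_parallel_size=4,
    enable_expert_parallel=True,               # assumed EP setup, elided in the diff
    enable_eplb=True,
    eplb_config={"use_async": True},           # the new EPLBConfig.use_async flag
)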
115 changes: 115 additions & 0 deletions vllm/distributed/eplb/async_worker.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""

import asyncio
import threading
from typing import TYPE_CHECKING

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.parallel_state import get_ep_group
from vllm.logger import init_logger

from .rebalance_execute import transfer_layer

if TYPE_CHECKING:
    from .eplb_state import EplbState

logger = init_logger(__name__)


def start_async_worker(
    state: "EplbState",
    rank_mapping: dict[int, int] | None = None,
    is_profile: bool = False,
) -> threading.Thread:
    ep_group = get_ep_group().device_group
    rank = ep_group.rank()
    device_index = state.cuda_device_index

    def thread_target() -> None:
        assert device_index is not None
        torch.cuda.set_device(device_index)
        cuda_stream = torch.cuda.Stream(device=device_index)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                transfer_run_periodically(
                    state=state,
                    ep_group=ep_group,
                    is_profile=is_profile,
                    rank_mapping=rank_mapping,
                    cuda_stream=cuda_stream,
                )
            )
        except Exception as exc:  # pragma: no cover - diagnostic path
            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


async def transfer_run_periodically(
    state: "EplbState",
    ep_group: ProcessGroup,
    is_profile: bool = False,
    rank_mapping: dict[int, int] | None = None,
    cuda_stream: torch.cuda.Stream | None = None,
) -> None:
    while True:
        await asyncio.to_thread(state.rearrange_event.wait)
        logger.info("async worker woke up for EPLB transfer")

        for model_state in state.model_states.values():
            if not model_state.is_async_enabled:
                continue
            current_num_layers = model_state.model.num_moe_layers
            while (
                model_state.rebalanced
                and model_state.layer_to_transfer < current_num_layers
            ):
                if (
                    not model_state.ep_buffer_ready
                    and model_state.rebalanced
                    and model_state.new_physical_to_logical_map is not None
                ):
                    await asyncio.to_thread(model_state.buffer_lock.acquire)
                    try:
                        if model_state.layer_to_transfer >= current_num_layers:
                            break

                        (
                            model_state.is_unchanged,
                            model_state.is_received_locally,
                            model_state.experts_recv_loc,
                        ) = await transfer_layer(
                            old_global_expert_indices=model_state.physical_to_logical_map,
                            new_global_expert_indices=model_state.new_physical_to_logical_map,
                            expert_weights=model_state.model.expert_weights,
                            expert_weights_buffer=model_state.expert_buffer,
                            ep_group=ep_group,
                            is_profile=is_profile,
                            layer=model_state.layer_to_transfer,
                            cuda_stream=cuda_stream,
                            rank_mapping=rank_mapping,
                        )
                        event = torch.cuda.Event(blocking=False)
                        cuda_stream.record_event(event)
                        model_state.buffer_ready_event = event
                        model_state.ep_buffer_ready = 1
                    finally:
                        model_state.buffer_lock.release()
                else:
                    if not model_state.rebalanced:
                        break
                    await asyncio.sleep(0.001)

        state.rearrange_event.clear()
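Editor's note: the pattern used above, a daemon thread that owns its own asyncio event loop and CUDA side stream, blocks on a threading.Event until the main thread requests a rebalance, and records a CUDA event once the copies are enqueued, can be seen in isolation below. This is a minimal sketch of that pattern, not vLLM code; the demo_* names and the placeholder work inside the stream context are invented for illustration.

# Minimal sketch (hypothetical, for illustration only) of the worker pattern:
# a daemon thread runs its own asyncio loop on a dedicated CUDA stream and is
# woken by a threading.Event set from the main thread.
import asyncio
import threading

import torch


def demo_start_worker(wake: threading.Event, device_index: int) -> threading.Thread:
    def thread_target() -> None:
        torch.cuda.set_device(device_index)
        side_stream = torch.cuda.Stream(device=device_index)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(demo_run(wake, side_stream))
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread


async def demo_run(wake: threading.Event, side_stream: torch.cuda.Stream) -> None:
    while True:
        # Block in a helper thread so the event loop itself stays responsive.
        await asyncio.to_thread(wake.wait)
        with torch.cuda.stream(side_stream):
            pass  # the real worker enqueues per-layer weight copies here
        done = torch.cuda.Event(blocking=False)
        side_stream.record_event(done)  # the main thread can later wait on this
        wake.clear()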