diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 9498e75b279b..781dfd44c1ef 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import random import pytest import torch import torch.distributed -from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace +from vllm.distributed.eplb.rebalance_execute import ( + move_from_buffer, + rearrange_expert_weights_inplace, + transfer_layer, +) from vllm.distributed.parallel_state import ( ensure_model_parallel_initialized, get_tp_group, @@ -231,6 +236,100 @@ def verify_redundant_experts_have_same_weights( ) +def _test_async_transfer_layer_without_mtp_worker( + env, + world_size: int, + num_layers: int, + num_local_experts: int, + num_logical_experts: int, +) -> None: + set_env_vars_and_device(env) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) + + tp_group = get_tp_group() + ep_group = tp_group.device_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + total_physical_experts = world_size * num_local_experts + hidden_sizes = [16, 32] + + redundancy_config = create_redundancy_config( + num_logical_experts, + total_physical_experts, + ) + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) + + new_redundancy_config = create_redundancy_config( + num_logical_experts, + total_physical_experts, + ) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) + + expert_weights = create_expert_weights( + num_layers, + num_local_experts, + hidden_sizes, + ep_rank, + device, + old_indices, + ) + + expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] + cuda_stream = torch.cuda.Stream(device=device) + + for layer_idx in range(num_layers): + is_unchanged, is_received_locally, experts_recv_loc = asyncio.run( + transfer_layer( + old_global_expert_indices=old_indices, + new_global_expert_indices=new_indices, + expert_weights=expert_weights, + expert_weights_buffer=expert_buffer, + ep_group=ep_group, + layer=layer_idx, + cuda_stream=cuda_stream, + ) + ) + + cuda_stream.synchronize() + move_from_buffer( + expert_weights=expert_weights[layer_idx], + expert_weights_buffer=expert_buffer, + is_unchanged=is_unchanged, + is_received_locally=is_received_locally, + experts_recv_loc=experts_recv_loc, + new_indices=new_indices[layer_idx].tolist(), + ep_group=ep_group, + ) + + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) + + def _test_rearrange_expert_weights_with_redundancy( env, world_size, num_layers, num_local_experts, num_logical_experts ) -> None: @@ -399,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None: ) +@pytest.mark.parametrize( + "world_size,num_layers,num_local_experts,num_logical_experts", + [ + (2, 2, 2, 3), + ], +) +def test_async_transfer_layer_without_mtp( + world_size: int, + num_layers: int, + num_local_experts: int, + num_logical_experts: int, +): + """Exercise async EPLB 
transfer path without MTP/spec decode.""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + distributed_run( + _test_async_transfer_layer_without_mtp_worker, + world_size, + num_layers, + num_local_experts, + num_logical_experts, + ) + + @pytest.mark.parametrize("world_size", [2, 4]) def test_rearrange_expert_weights_no_change(world_size): """ diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py index 11e23f128f33..c055b7a3f6dd 100644 --- a/tests/distributed/test_eplb_spec_decode.py +++ b/tests/distributed/test_eplb_spec_decode.py @@ -10,10 +10,11 @@ def get_model_args( model_name: str, - spec_model_name: str, + spec_model_name: str | None, spec_method: str, tp_size: int, model_max_len: int, + use_async: bool = False, ) -> dict: speculative_config = { "method": spec_method, @@ -37,6 +38,8 @@ def get_model_args( "enable_eplb": True, "max_model_len": model_max_len, } + if use_async: + model_args["eplb_config"] = {"use_async": True} return model_args @@ -94,3 +97,37 @@ def test_eplb_spec_decode( measured_value - RTOL < expected_gsm8k_value and measured_value + RTOL > expected_gsm8k_value ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}" + + +@large_gpu_mark(min_gb=80) +def test_eplb_spec_decode_qwen3_next_mtp_async() -> None: + """ + Ensure async EPLB works with MTP speculative decoding for Qwen3-Next. + """ + + TASK = "gsm8k" + FILTER = "exact_match,strict-match" + RTOL = 0.03 + expected_gsm8k_value = 0.86 + + model_args = get_model_args( + model_name="Qwen/Qwen3-Next-80B-A3B-Instruct", + spec_model_name=None, + spec_method="mtp", + tp_size=4, + model_max_len=4096, + use_async=True, + ) + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks=TASK, + batch_size=64, + num_fewshot=8, + ) + measured_value = results["results"][TASK][FILTER] + assert ( + measured_value - RTOL < expected_gsm8k_value + and measured_value + RTOL > expected_gsm8k_value + ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}" diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 4b0236d8de3f..ad438a8b464e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -60,6 +60,10 @@ class EPLBConfig: Log the balancedness each step of expert parallelism. This is turned off by default since it will cause communication overhead. """ + use_async: bool = False + """ + Whether to use non-blocking EPLB. + """ @config diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py new file mode 100644 index 000000000000..e4b4fc92eeaa --- /dev/null +++ b/vllm/distributed/eplb/async_worker.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +The async worker that transfers experts in the background. 
+""" + +import asyncio +import threading +from typing import TYPE_CHECKING + +import torch +from torch.distributed import ProcessGroup + +from vllm.distributed.parallel_state import get_ep_group +from vllm.logger import init_logger + +from .rebalance_execute import transfer_layer + +if TYPE_CHECKING: + from .eplb_state import EplbState + +logger = init_logger(__name__) + + +def start_async_worker( + state: "EplbState", + rank_mapping: dict[int, int] | None = None, + is_profile: bool = False, +) -> threading.Thread: + ep_group = get_ep_group().device_group + rank = ep_group.rank() + device_index = state.cuda_device_index + + def thread_target() -> None: + assert device_index is not None + torch.cuda.set_device(device_index) + cuda_stream = torch.cuda.Stream(device=device_index) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + transfer_run_periodically( + state=state, + ep_group=ep_group, + is_profile=is_profile, + rank_mapping=rank_mapping, + cuda_stream=cuda_stream, + ) + ) + except Exception as exc: # pragma: no cover - diagnostic path + logger.exception("async loop error (Rank %d): %s", rank, str(exc)) + finally: + loop.close() + + thread = threading.Thread(target=thread_target, daemon=True) + thread.start() + return thread + + +async def transfer_run_periodically( + state: "EplbState", + ep_group: ProcessGroup, + is_profile: bool = False, + rank_mapping: dict[int, int] | None = None, + cuda_stream: torch.cuda.Stream = None, +) -> None: + while True: + await asyncio.to_thread(state.rearrange_event.wait) + logger.info("async worker woke up for EPLB transfer") + + for model_state in state.model_states.values(): + if not model_state.is_async_enabled: + continue + current_num_layers = model_state.model.num_moe_layers + while ( + model_state.rebalanced + and model_state.layer_to_transfer < current_num_layers + ): + if ( + not model_state.ep_buffer_ready + and model_state.rebalanced + and model_state.new_physical_to_logical_map is not None + ): + await asyncio.to_thread(model_state.buffer_lock.acquire) + try: + if model_state.layer_to_transfer >= current_num_layers: + break + + ( + model_state.is_unchanged, + model_state.is_received_locally, + model_state.experts_recv_loc, + ) = await transfer_layer( + old_global_expert_indices=model_state.physical_to_logical_map, + new_global_expert_indices=model_state.new_physical_to_logical_map, + expert_weights=model_state.model.expert_weights, + expert_weights_buffer=model_state.expert_buffer, + ep_group=ep_group, + is_profile=is_profile, + layer=model_state.layer_to_transfer, + cuda_stream=cuda_stream, + rank_mapping=rank_mapping, + ) + event = torch.cuda.Event(blocking=False) + cuda_stream.record_event(event) + model_state.buffer_ready_event = event + model_state.ep_buffer_ready = 1 + finally: + model_state.buffer_lock.release() + else: + if not model_state.rebalanced: + break + await asyncio.sleep(0.001) + + state.rearrange_event.clear() diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 526d3ceac7b8..9f8798a96a2f 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -26,6 +26,7 @@ physical experts. 
""" +import threading import time from collections.abc import Sequence from dataclasses import dataclass @@ -43,8 +44,9 @@ from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts +from .async_worker import start_async_worker from .rebalance_algo import rebalance_experts -from .rebalance_execute import rearrange_expert_weights_inplace +from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace logger = init_logger(__name__) @@ -132,6 +134,74 @@ class EplbModelState: """ model_name: str model: MixtureOfExperts + expert_buffer: list[torch.Tensor] + """ + The buffer to store the expert weights during transfer. + """ + buffer_lock: threading.Lock + """ + The lock to protect the expert buffer. + """ + buffer_ready_event: torch.cuda.Event | None + """ + CUDA event recorded when the async worker finishes filling the buffer. + The main thread waits on this before consuming the buffer. + """ + ep_buffer_ready: int + """ + The flag indicates whether the expert buffer is ready for transfer. + 0 or 1. + """ + layer_to_transfer: int + """ + The layer index to transfer in async mode. + """ + rebalanced: bool + """ + The flag indicates whether the experts rebalance have been computed. + """ + pending_global_ready_check: bool + """ + Whether the async EPLB needs to poll peers for buffer readiness. + """ + is_unchanged: list[bool] + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + The size is same as the num of physical experts in the current layer. + """ + is_received_locally: list[bool] + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + The size is same as the num of physical experts in the current layer. + """ + experts_recv_loc: dict[int, int] + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + The size is same as the num of physical experts in the current layer. + """ + is_async_enabled: bool + """ + The flag indicates whether the EPLB is running in async mode. + """ + cuda_device_index: int | None + """ + CUDA device index for the async EPLB worker thread. + """ + new_physical_to_logical_map: torch.Tensor | None = None + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + the size is same as physical_to_logical_map + """ + new_logical_to_physical_map: torch.Tensor | None = None + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + the size is same as logical_to_physical_map + """ + new_logical_replica_count: torch.Tensor | None = None + """ + intermediate variable between `move_to_buffer` and `move_to_workspace`. + the size is same as logical_replica_count + """ class EplbState: @@ -164,12 +234,31 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device): Otherwise, the rearrangement will hang at collective communication calls. """ - self.expert_rearrangement_step: int = 0 + self.expert_rearrangement_step_interval: int = 0 """ Interval for expert rearrangement steps. This is a constant and is taken from the config. """ - self.expert_rearrangement_step_interval: int = 0 + self.is_async: bool = False + """ + The flag indicates whether the EPLB is running in async mode. + """ + self.rearrange_event = threading.Event() + """ + Event to signal when a new rearrangement is needed for the async thread. + """ + self.async_worker: threading.Thread | None = None + """ + Background thread handling async transfers. 
+ """ + self.cuda_device_index: int | None = None + """ + CUDA device index for the async EPLB worker thread. + """ + if self.device.type == "cuda": + self.cuda_device_index = self.device.index + if self.cuda_device_index is None and torch.cuda.is_available(): + self.cuda_device_index = torch.cuda.current_device() @staticmethod def build_initial_global_physical_to_logical_map( @@ -239,6 +328,8 @@ def add_model( Build the initial EPLB state. """ self.validate_ep_configuration(model) + self.is_async = self.parallel_config.eplb_config.use_async + physical_to_logical_map_list = ( EplbState.build_initial_global_physical_to_logical_map( model.num_routed_experts, @@ -368,7 +459,12 @@ def add_model( physical_to_logical_map = new_physical_to_logical_map.to(self.device) logical_to_physical_map.copy_(new_logical_to_physical_map) logical_replica_count.copy_(new_logical_replica_count) + else: + new_physical_to_logical_map = None + + new_logical_to_physical_map = None + new_logical_replica_count = None model.set_eplb_state( expert_load_pass, logical_to_physical_map, @@ -385,15 +481,33 @@ def add_model( ) self.expert_rearrangement_step = 0 - self.model_states[model_config.compute_hash()] = EplbModelState( - physical_to_logical_map, - logical_to_physical_map, - logical_replica_count, - expert_load_pass, - expert_load_window, - model_config.model, - model, + expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] + + model_state = EplbModelState( + physical_to_logical_map=physical_to_logical_map, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + expert_load_pass=expert_load_pass, + expert_load_window=expert_load_window, + model_name=model_config.model, + model=model, + expert_buffer=expert_buffer, + buffer_lock=threading.Lock(), + buffer_ready_event=None, + ep_buffer_ready=0, + layer_to_transfer=0, + rebalanced=False, + pending_global_ready_check=False, + is_unchanged=[], + is_received_locally=[], + experts_recv_loc={}, + is_async_enabled=self.is_async, + cuda_device_index=self.cuda_device_index, + new_physical_to_logical_map=new_physical_to_logical_map, + new_logical_to_physical_map=new_logical_to_physical_map, + new_logical_replica_count=new_logical_replica_count, ) + self.model_states[model_config.compute_hash()] = model_state def step( self, @@ -420,7 +534,7 @@ def step( - `max_tokens`: The maximum load across ranks. - `balancedness`: The ratio of average load to maximum load. """ - + ep_group = get_ep_group().device_group if is_profile: self.rearrange(is_profile=True) return @@ -488,7 +602,49 @@ def step( # rearrangement step and perform rearrangement to ensure all ranks are # performing collective communication. 
self.expert_rearrangement_step += 1 + + if self.is_async: + for eplb_model_state in self.model_states.values(): + if not eplb_model_state.is_async_enabled: + continue + + all_ranks_buffer_ready = False + if eplb_model_state.pending_global_ready_check: + all_ranks_buffer_ready = self._all_ranks_buffer_ready( + eplb_model_state + ) + if ( + eplb_model_state.is_async_enabled + and eplb_model_state.ep_buffer_ready + and all_ranks_buffer_ready + ): + self.move_to_workspace( + model_state=eplb_model_state, + ep_group=ep_group, + is_profile=is_profile, + ) + if ( + eplb_model_state.layer_to_transfer + >= eplb_model_state.model.num_moe_layers + ): + self.post_eplb(eplb_model_state, is_profile) + eplb_model_state.rebalanced = False + eplb_model_state.layer_to_transfer = 0 + eplb_model_state.pending_global_ready_check = False + logger.info( + "finish async transfer for model %s rank %d layer %d", + eplb_model_state.model_name, + ep_group.rank(), + eplb_model_state.model.num_moe_layers, + ) + if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: + if any( + eplb_model_state.is_async_enabled and eplb_model_state.rebalanced + for eplb_model_state in self.model_states.values() + ): + # Still performing asynchronous rearrangement + return self.expert_rearrangement_step = 0 self.rearrange() @@ -524,7 +680,11 @@ def rearrange( if is_main_rank: torch.cuda.synchronize() time_start = time.perf_counter() - logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") + logger.info( + "Rearranging experts %s %s...", + "(async mode)" if self.is_async else "sync mode", + "(profile)" if is_profile else "", + ) if global_expert_loads is None: # Map the physical expert load to global logical experts @@ -593,6 +753,7 @@ def rearrange( model = eplb_model_state.model num_replicas = model.num_physical_experts num_groups = model.num_expert_groups + if rank_mapping is not None and len(rank_mapping) == ep_group.size(): # NOTE(yongji): scale down, we need to rebalance the experts on # remaining GPUs, transfer the experts while we haven't shutdown @@ -608,7 +769,7 @@ def rearrange( num_gpus = ep_group.size() if num_gpus % num_nodes != 0: - self.num_nodes = 1 + num_nodes = 1 logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" @@ -631,59 +792,215 @@ def rearrange( num_gpus, ) - # Update expert weights - rearrange_expert_weights_inplace( - eplb_model_state.physical_to_logical_map, - new_physical_to_logical_map, - eplb_model_state.model.expert_weights, - ep_group, - is_profile, - rank_mapping, - ) + if not eplb_model_state.is_async_enabled or is_profile: + # Update expert weights + rearrange_expert_weights_inplace( + eplb_model_state.physical_to_logical_map, + new_physical_to_logical_map, + eplb_model_state.model.expert_weights, + ep_group, + is_profile, + rank_mapping, + ) - if not is_profile: - if ( - eplb_model_state.physical_to_logical_map.shape[1] - != new_physical_to_logical_map.shape[1] - ): - eplb_model_state.physical_to_logical_map = ( - new_physical_to_logical_map.to( - eplb_model_state.physical_to_logical_map.device + if not is_profile: + if ( + eplb_model_state.physical_to_logical_map.shape[1] + != new_physical_to_logical_map.shape[1] + ): + eplb_model_state.physical_to_logical_map = ( + new_physical_to_logical_map.to( + eplb_model_state.physical_to_logical_map.device + ) ) + else: + eplb_model_state.physical_to_logical_map.copy_( + new_physical_to_logical_map + ) + max_physical_slots = new_logical_to_physical_map.shape[-1] + 
assert ( + max_physical_slots + <= eplb_model_state.logical_to_physical_map.shape[-1] ) - else: - eplb_model_state.physical_to_logical_map.copy_( - new_physical_to_logical_map + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + ( + 0, + eplb_model_state.logical_to_physical_map.shape[-1] + - max_physical_slots, + ), + value=-1, ) - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert ( - max_physical_slots - <= eplb_model_state.logical_to_physical_map.shape[-1] - ) - new_logical_to_physical_map = torch.nn.functional.pad( + eplb_model_state.logical_to_physical_map.copy_( + new_logical_to_physical_map + ) + eplb_model_state.logical_replica_count.copy_( + new_logical_replica_count + ) + if is_main_rank: + assert time_start is not None + torch.cuda.synchronize() + time_end = time.perf_counter() + logger.info( + "Rearranged experts%sin %.2f seconds.", + " (profile) " if is_profile else " ", + time_end - time_start, + ) + else: + device = eplb_model_state.physical_to_logical_map.device + new_physical = new_physical_to_logical_map.to(device) + max_slots = eplb_model_state.logical_to_physical_map.shape[-1] + padded_logical = torch.nn.functional.pad( new_logical_to_physical_map, - ( - 0, - eplb_model_state.logical_to_physical_map.shape[-1] - - max_physical_slots, - ), + (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])), value=-1, + ).to(eplb_model_state.logical_to_physical_map.device) + new_replica = new_logical_replica_count.to( + eplb_model_state.logical_replica_count.device ) - eplb_model_state.logical_to_physical_map.copy_( - new_logical_to_physical_map - ) - eplb_model_state.logical_replica_count.copy_(new_logical_replica_count) - if is_main_rank: - assert time_start is not None - torch.cuda.synchronize() - time_end = time.perf_counter() + eplb_model_state.new_physical_to_logical_map = new_physical + eplb_model_state.new_logical_to_physical_map = padded_logical + eplb_model_state.new_logical_replica_count = new_replica + + eplb_model_state.rebalanced = True + eplb_model_state.layer_to_transfer = 0 + eplb_model_state.pending_global_ready_check = True + + # Signal async thread to start transferring layers + if self.is_async and (not is_profile): + self.rearrange_event.set() + return None + + def start_async_loop( + self, + rank_mapping: dict[int, int] | None = None, + is_profile: bool = False, + ): + if not self.is_async: + return + if self.async_worker is None: + self.async_worker = start_async_worker( + self, + rank_mapping=rank_mapping, + is_profile=is_profile, + ) + + def _update_layer_mapping_from_new( + self, model_state: EplbModelState, layer: int + ) -> None: + if ( + model_state.new_physical_to_logical_map is None + or model_state.new_logical_to_physical_map is None + or model_state.new_logical_replica_count is None + ): + return + + target_device = model_state.physical_to_logical_map.device + new_physical = model_state.new_physical_to_logical_map + if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]: + model_state.physical_to_logical_map = new_physical.to(target_device) + else: + model_state.physical_to_logical_map[layer].copy_( + new_physical[layer].to(target_device) + ) + + logical_device = model_state.logical_to_physical_map.device + new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device) + max_slots = model_state.logical_to_physical_map.shape[-1] + slot_delta = max_slots - new_logical.shape[-1] + if slot_delta > 0: + new_logical = torch.nn.functional.pad( + new_logical, (0, 
slot_delta), value=-1 + ) + model_state.logical_to_physical_map[layer].copy_(new_logical) + + replica_device = model_state.logical_replica_count.device + model_state.logical_replica_count[layer].copy_( + model_state.new_logical_replica_count[layer].to(replica_device) + ) + + def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool: + parallel_state = get_ep_group() + cpu_group = getattr(parallel_state, "cpu_group", None) + if cpu_group is not None and cpu_group.size() > 1: + flag = torch.tensor( + (int(model_state.ep_buffer_ready),), dtype=torch.int32, device="cpu" + ) + all_reduce(flag, group=cpu_group) + return int(flag.item()) == cpu_group.size() + + device_group = parallel_state.device_group + if device_group.size() <= 1: + return bool(model_state.ep_buffer_ready) + + device = getattr( + parallel_state, "device", model_state.physical_to_logical_map.device + ) + flag = torch.tensor( + (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device + ) + all_reduce(flag, group=device_group) + return int(flag.item()) == device_group.size() + + def move_to_workspace( + self, + model_state: EplbModelState, + ep_group: ProcessGroup, + is_profile: bool = False, + ): + if not model_state.buffer_lock.acquire(blocking=False): + return + try: + assert model_state.new_physical_to_logical_map is not None + device_index = model_state.cuda_device_index or self.cuda_device_index + if model_state.buffer_ready_event is not None and device_index is not None: + stream = torch.cuda.current_stream(device=device_index) + stream.wait_event(model_state.buffer_ready_event) + model_state.buffer_ready_event = None + move_from_buffer( + expert_weights=model_state.model.expert_weights[ + model_state.layer_to_transfer + ], + expert_weights_buffer=model_state.expert_buffer, + is_unchanged=model_state.is_unchanged, + is_received_locally=model_state.is_received_locally, + experts_recv_loc=model_state.experts_recv_loc, + new_indices=model_state.new_physical_to_logical_map[ + model_state.layer_to_transfer + ].tolist(), + ep_group=ep_group, + ) + transferred_layer = model_state.layer_to_transfer + self._update_layer_mapping_from_new(model_state, transferred_layer) + # After the main thread consumes, advance layer_to_transfer + model_state.layer_to_transfer += 1 + model_state.ep_buffer_ready = 0 logger.info( - "Rearranged experts%sin %.2f seconds.", - " (profile) " if is_profile else " ", - time_end - time_start, + "model %s successfully move_to_workspace layer %d", + model_state.model_name, + transferred_layer, ) - return None + finally: + try: + model_state.buffer_lock.release() + except Exception as e: + logger.error( + "Rank %d: buffer_lock release failed in move_to_workspace: %s", + ep_group.rank(), + str(e), + ) + + def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None: + assert model_state.new_physical_to_logical_map is not None + assert model_state.new_logical_to_physical_map is not None + assert model_state.new_logical_replica_count is not None + if not is_profile: + for layer_idx in range(model_state.physical_to_logical_map.shape[0]): + self._update_layer_mapping_from_new(model_state, layer_idx) + model_state.new_physical_to_logical_map = None + model_state.new_logical_to_physical_map = None + model_state.new_logical_replica_count = None @staticmethod def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]: diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 5c1efbaf03ba..376dad8a72ef 100644 --- 
a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -100,18 +100,19 @@ def get_ep_ranks_with_expert(
     return ranks_to_send, ranks_to_recv_actual
 
 
-def shuffle_layer(
+def move_to_buffer(
     num_local_experts: int,
-    ep_rank: int,
     old_indices: Sequence[int],
     new_indices: Sequence[int],
     expert_weights: Iterable[torch.Tensor],
     expert_weights_buffer: Sequence[torch.Tensor],
+    cuda_stream: torch.cuda.Stream | None,
     ep_group: ProcessGroup,
-) -> None:
+) -> tuple[list[bool], list[bool], dict[int, int]]:
     """
     Perform expert weights rearrangement of one layer.
     """
+    ep_rank = ep_group.rank()
    local2global = partial(
        idx_local_to_global,
        local_cnt=num_local_experts,
@@ -137,7 +138,8 @@ def shuffle_layer(
            if old_indices[src_global] == new_indices[dst_global]:
                is_received_locally[dst] = True
                for weight, buffer in zip(expert_weights, expert_weights_buffer):
-                    buffer[dst].copy_(weight[src])
+                    with torch.cuda.stream(cuda_stream):
+                        buffer[dst].copy_(weight[src], non_blocking=True)
 
    p2p_ops: list[P2POp] = []
 
@@ -225,25 +227,115 @@ def shuffle_layer(
    ]
 
    # 4. Execute the P2P operations. The real communication happens here.
-    if p2p_ops:
+    if p2p_ops and cuda_stream is not None:
+        with torch.cuda.stream(cuda_stream):
+            reqs = batch_isend_irecv(p2p_ops)
+            for req in reqs:
+                req.wait()
+    elif p2p_ops:
        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()
+    # wait for the communication to finish
+    return is_unchanged, is_received_locally, experts_recv_loc
+
+
+def move_from_buffer(
+    expert_weights: Iterable[torch.Tensor],
+    expert_weights_buffer: list[torch.Tensor],
+    is_unchanged: list[bool],
+    is_received_locally: list[bool],
+    experts_recv_loc: dict[int, int],
+    new_indices: Sequence[int],
+    ep_group: ProcessGroup,
+) -> None:
+    ep_rank = ep_group.rank()
+    num_local_experts = len(is_unchanged)
+
+    local2global = partial(
+        idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
+    )
 
-    # 5. Copy the weights from the buffer back to the original weights.
    for dst in range(num_local_experts):
        if is_unchanged[dst]:
            continue
        if is_received_locally[dst]:
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
-                weight[dst].copy_(buffer[dst])
+                weight[dst].copy_(buffer[dst], non_blocking=True)
        else:
            expert = new_indices[local2global(dst)]
            if expert == -1:
                continue
            src = experts_recv_loc[expert]
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
-                weight[dst].copy_(buffer[src])
+                weight[dst].copy_(buffer[src], non_blocking=True)
+
+
+async def transfer_layer(
+    old_global_expert_indices: torch.Tensor,
+    new_global_expert_indices: torch.Tensor,
+    expert_weights: Sequence[Iterable[torch.Tensor]],
+    expert_weights_buffer: Sequence[torch.Tensor],
+    ep_group: ProcessGroup,
+    is_profile: bool = False,
+    layer: int = 0,
+    cuda_stream: torch.cuda.Stream | None = None,
+    rank_mapping: dict[int, int] | None = None,
+) -> tuple[list[bool], list[bool], dict[int, int]]:
+    """
+    Transfer one layer's expert weights into the buffer per the new expert indices.
+
+    The values of the indices arguments are logical indices of the experts,
+    while the keys are physical.
+
+    Args:
+        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
+            of tensors of shape (num_local_physical_experts, hidden_size_i).
+            For example, a linear layer may have up and down projection,
+            so weight_count = 2.
Each weight's hidden size can be different. + ep_group: The device process group for expert parallelism. + is_profile (bool): If `True`, do not perform any actual weight copy. + This is used during profile run, where we only perform dummy + communications to reserve enough memory for the buffers. + """ + ep_size = ep_group.size() + if rank_mapping is not None: + if len(rank_mapping) == ep_group.size(): + # scale down + new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices, + rank_mapping, + ) + else: + # scale up + old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices, + rank_mapping, + ep_group.size(), + ) + + assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] + num_moe_layers, num_physical_experts = old_global_expert_indices.shape + assert len(expert_weights) == num_moe_layers + num_local_physical_experts = next(iter(expert_weights[0])).shape[0] + assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) + assert num_physical_experts == ep_size * num_local_physical_experts + # A buffer to hold the expert weights in one layer during the exchange. + # NOTE: Currently we assume the same weights across different layers + # have the same shape. + + is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + num_local_experts=num_local_physical_experts, + old_indices=old_global_expert_indices[layer].tolist(), + new_indices=new_global_expert_indices[layer].tolist(), + expert_weights=expert_weights[layer], + expert_weights_buffer=expert_weights_buffer, + cuda_stream=cuda_stream, + ep_group=ep_group, + ) + return is_unchanged, is_received_locally, experts_recv_loc def rearrange_expert_weights_inplace( @@ -296,7 +388,6 @@ def rearrange_expert_weights_inplace( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) - ep_rank = ep_group.rank() ep_size = ep_group.size() assert num_physical_experts == ep_size * num_local_physical_experts @@ -329,14 +420,24 @@ def rearrange_expert_weights_inplace( torch.cuda.synchronize() for layer in range(num_moe_layers): - shuffle_layer( - num_local_physical_experts, - ep_rank, - old_global_expert_indices_cpu[layer].tolist(), - new_global_expert_indices_cpu[layer].tolist(), - expert_weights[layer], - expert_weights_buffer, - ep_group, + is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + num_local_experts=num_local_physical_experts, + old_indices=old_global_expert_indices_cpu[layer].tolist(), + new_indices=new_global_expert_indices_cpu[layer].tolist(), + expert_weights=expert_weights[layer], + expert_weights_buffer=expert_weights_buffer, + cuda_stream=None, + ep_group=ep_group, + ) + + move_from_buffer( + expert_weights=expert_weights[layer], + expert_weights_buffer=expert_weights_buffer, + is_unchanged=is_unchanged, + is_received_locally=is_received_locally, + experts_recv_loc=experts_recv_loc, + new_indices=new_global_expert_indices[layer].tolist(), + ep_group=ep_group, ) @@ -428,4 +529,4 @@ def _map_new_expert_indices_with_rank_mapping( return mapped_expert_indices -__all__ = ["rearrange_expert_weights_inplace"] +__all__ = ["transfer_layer", "move_from_buffer"] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6a54e02f861e..cbafc9c993cc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3370,6 +3370,8 @@ def 
load_model(self, eep_scale_up: bool = False) -> None: old_global_expert_indices, rank_mapping, ) + if self.eplb_state.is_async: + self.eplb_state.start_async_loop(rank_mapping=rank_mapping) if ( self.vllm_config.compilation_config.mode