From 178cb9f7182b8cbd9f0c544c2464b559420f5fea Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Sat, 23 Aug 2025 10:23:29 +0800
Subject: [PATCH] Implement overlap of prepare_inputs with execute_model

Signed-off-by: Ronald1995
---
 vllm/config/scheduler.py                      |  10 ++
 .../device_communicators/shm_broadcast.py     |  45 ++++--
 vllm/distributed/utils.py                     |  11 +-
 vllm/engine/arg_utils.py                      |  11 ++
 vllm/v1/engine/core.py                        |   8 +
 vllm/v1/executor/multiproc_executor.py        |  38 ++++-
 vllm/v1/sample/rejection_sampler.py           |   2 +-
 vllm/v1/worker/gpu_model_runner.py            | 142 +++++++++++++-----
 8 files changed, 209 insertions(+), 58 deletions(-)

diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 93002012799a..f74227af95be 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -159,6 +159,12 @@ class SchedulerConfig:
     structured outputs, speculative decoding, and pipeline parallelism.
     """
 
+    async_execute_model: bool = False
+    """EXPERIMENTAL: If set to True, perform async model execution.
+    This may help reduce CPU overheads, leading to better latency
+    and throughput. Note that this requires async scheduling.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -247,6 +253,10 @@ def __post_init__(self) -> None:
             self.scheduler_cls = (
                 "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
 
+        if self.async_execute_model:
+            assert self.async_scheduling, (
+                "async_execute_model requires async_scheduling to be True.")
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index c7810043b81e..6627263acb97 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -33,8 +33,8 @@ class SpinTimer:
     def record_activity(self):
         pass
 
-    def spin(self):
-        sched_yield()
+    def spin(self, sleep_time: Optional[float] = None):
+        sched_yield(sleep_time)
 
 
 class SpinSleepTimer(SpinTimer):
@@ -370,7 +370,11 @@ def wait_until_ready(self):
         assert recv == b"READY"
 
     @contextmanager
-    def acquire_write(self, timeout: Optional[float] = None):
+    def acquire_write(
+        self,
+        timeout: Optional[float] = None,
+        sleep_time: Optional[float] = None,
+    ):
         assert self._is_writer, "Only writers can acquire write"
         start_time = time.monotonic()
         n_warning = 1
@@ -385,7 +389,7 @@ def acquire_write(self, timeout: Optional[float] = None):
                 # we need to wait until it is read by all readers
 
                 # Release the processor to other threads
-                sched_yield()
+                sched_yield(sleep_time)
 
                 # if we wait for a long time, log a message
                 if (time.monotonic() - start_time
@@ -428,9 +432,12 @@ def acquire_write(self, timeout: Optional[float] = None):
                 break
 
     @contextmanager
-    def acquire_read(self,
-                     timeout: Optional[float] = None,
-                     cancel: Optional[Event] = None):
+    def acquire_read(
+        self,
+        timeout: Optional[float] = None,
+        cancel: Optional[Event] = None,
+        sleep_time: Optional[float] = None,
+    ):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
@@ -448,7 +455,7 @@ def acquire_read(self,
                 # we need to wait until it is written
 
                 # Release the processor to other threads
-                self._read_spin_timer.spin(sleep_time)
+                self._read_spin_timer.spin(sleep_time)
 
                 # if we wait for a long time, log a message
                 if (time.monotonic() - start_time
@@ -483,28 +490,36 @@ def acquire_read(self,
                 self._read_spin_timer.record_activity()
                 break
 
-    def enqueue(self, obj, timeout: Optional[float] = None):
+    def enqueue(
+        self,
+        obj,
+        timeout: Optional[float] = None,
+        sleep_time: Optional[float] = None,
+    ):
         """ Write to message queue with optional timeout (in seconds) """
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
         if self.n_local_reader > 0:
             if len(serialized_obj) >= self.buffer.max_chunk_bytes:
-                with self.acquire_write(timeout) as buf:
+                with self.acquire_write(timeout, sleep_time) as buf:
                     buf[0] = 1  # overflow
                 self.local_socket.send(serialized_obj)
             else:
-                with self.acquire_write(timeout) as buf:
+                with self.acquire_write(timeout, sleep_time) as buf:
                     buf[0] = 0  # not overflow
                     buf[1:len(serialized_obj) + 1] = serialized_obj
         if self.n_remote_reader > 0:
             self.remote_socket.send(serialized_obj)
 
-    def dequeue(self,
-                timeout: Optional[float] = None,
-                cancel: Optional[Event] = None):
+    def dequeue(
+        self,
+        timeout: Optional[float] = None,
+        cancel: Optional[Event] = None,
+        sleep_time: Optional[float] = None,
+    ) -> Any:
         """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read(timeout, cancel) as buf:
+            with self.acquire_read(timeout, cancel, sleep_time) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 67f71643d039..e248e2c6e569 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -38,8 +38,15 @@
                        and sys.version_info[2] >= 8))
 
 
-def sched_yield():
-    if USE_SCHED_YIELD:
+def sched_yield(sleep_time: Optional[float] = None):
+    # When the worker process runs more than one thread, os.sched_yield()
+    # and time.sleep(0) only move the calling thread to the ready state, so
+    # the CPU may reschedule it immediately. Sleeping for a small amount of
+    # time moves the thread to the blocked state and lets the CPU schedule
+    # the other threads.
+    if sleep_time is not None:
+        time.sleep(sleep_time)
+    elif USE_SCHED_YIELD:
         os.sched_yield()
     else:
         time.sleep(0)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index b0f50b4429a8..eb1b8730b22f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -445,6 +445,8 @@ class EngineArgs:
 
     async_scheduling: bool = SchedulerConfig.async_scheduling
 
+    async_execute_model: bool = SchedulerConfig.async_execute_model
+
     kv_sharing_fast_prefill: bool = \
         CacheConfig.kv_sharing_fast_prefill
 
@@ -864,6 +866,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument("--async-scheduling",
                                      **scheduler_kwargs["async_scheduling"])
 
+        scheduler_group.add_argument("--async-execute-model",
+                                     **scheduler_kwargs["async_execute_model"])
+
         # vLLM arguments
         vllm_kwargs = get_kwargs(VllmConfig)
         vllm_group = parser.add_argument_group(
@@ -1254,6 +1259,12 @@ def create_engine_config(
             raise ValueError("Async scheduling is not supported with "
                              "pipeline-parallel-size > 1.")
 
+        if self.async_execute_model:
+            # TODO(Ronald1995): Support async execute model with Ray.
+            if self.distributed_executor_backend != "mp":
+                raise ValueError("Async execute model is only supported with "
+                                 "the mp-based distributed executor backend.")
+
         # Currently, async scheduling does not support speculative decoding.
         # TODO(woosuk): Support it.
        if self.speculative_config is not None:
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 32765cda6482..49dfc1556901 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -155,6 +155,7 @@ def __init__(self,
 
         self.request_block_hasher = get_request_block_hasher(
             block_size, caching_hash_fn)
+        self.async_execute_model = self.vllm_config.scheduler_config.async_execute_model
 
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
@@ -341,6 +342,13 @@ def step_with_batch_queue(
         # but peeking the first element in a queue is not thread-safe,
         # so we need more work.
         if not scheduled_batch and not self.batch_queue.empty():
+            # When async_execute_model is enabled, we must not block on the
+            # future when total_num_scheduled_tokens is 0, because no
+            # execute_model task was sent to the workers in that case.
+            if (self.async_execute_model
+                    and scheduler_output.total_num_scheduled_tokens == 0):
+                return engine_core_outputs, scheduled_batch
+
             future, scheduler_output = self.batch_queue.get_nowait()
 
             # Blocking until the first result is available.
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 15b88a212899..680934fc4f5f 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -16,6 +16,7 @@
 from multiprocessing.process import BaseProcess
 from threading import Thread
 from typing import Any, Callable, Optional, Union, cast
+import queue
 
 import cloudpickle
 
@@ -403,6 +404,12 @@ def __init__(
         # Initializes a message queue for sending the model output
         self.worker_response_mq = MessageQueue(1, 1)
 
+        # The queue size and thread pool size are set to 2 to match the
+        # executor's max_concurrent_batches when async scheduling is enabled.
+        self.exe_queue = queue.Queue(2)
+        self.exe_thread_pool = ThreadPoolExecutor(
+            max_workers=2, thread_name_prefix="execute_model")
+
         # Initialize device and loads weights
         self.worker.init_device()
         self.worker.load_model()
@@ -586,6 +593,12 @@ class ResponseStatus(Enum):
 
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
+        async_execute_model = self.worker.vllm_config.scheduler_config.async_execute_model
+        events = {
+            "d2h_copy_event": threading.Event(),
+            "update_sampled_tokens_event": threading.Event()
+        }
+        exe_count = 0
         while True:
             method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
 
@@ -594,7 +607,19 @@ def worker_busy_loop(self):
                     func = getattr(self.worker, method)
                 elif isinstance(method, bytes):
                     func = partial(cloudpickle.loads(method), self.worker)
-                output = func(*args, **kwargs)
+
+                if async_execute_model and func.__name__ == "execute_model":
+                    args = (*args, exe_count, events)
+                    output = self.execute_model_with_queue(
+                        func,
+                        *args,
+                        **kwargs,
+                    )
+                    exe_count += 1
+                    if not output:
+                        continue
+                else:
+                    output = func(*args, **kwargs)
             except Exception as e:
                 # Notes have been introduced in python 3.11
                 if hasattr(e, "add_note"):
@@ -610,3 +635,14 @@ def worker_busy_loop(self):
             if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))
+
+    def execute_model_with_queue(self, func, *args, **kwargs):
+        """Execute model with a queue for async execution."""
+        output = None
+        if not self.exe_queue.full():
+            output_future = self.exe_thread_pool.submit(func, *args, **kwargs)
+            self.exe_queue.put_nowait(output_future)
+        if self.exe_queue.full():
+            output = self.exe_queue.get().result()
+            self.exe_queue.task_done()
+        return output
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index b2354c53302a..c926fde2b8b5 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -121,7 +121,7 @@ def parse_output(
         Returns:
             A list of lists of token IDs.
         """
-        output_token_ids_np = output_token_ids.cpu().numpy()
+        output_token_ids_np = output_token_ids.numpy()
         # Create mask for valid tokens.
         valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                       (output_token_ids_np < vocab_size))
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7caa873be444..d278c1f0f351 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -348,6 +348,7 @@ def __init__(
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
+        self.async_execute_model = self.vllm_config.scheduler_config.async_execute_model
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
@@ -746,14 +747,6 @@ def _prepare_inputs(
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
 
-        # NOTE(woosuk): We use torch.index_select instead of np.take here
-        # because torch.index_select is much faster than np.take for large
-        # tensors.
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(
@@ -777,9 +770,6 @@ def _prepare_inputs(
             seq_lens = self.seq_lens[:num_reqs]
             max_seq_len = self.seq_lens_np[:num_reqs].max().item()
 
-        # Copy the tensors to the GPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
@@ -944,9 +934,42 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return (attn_metadata, logits_indices, spec_decode_metadata,
-                num_scheduled_tokens, spec_decode_common_attn_metadata,
-                max_num_scheduled_tokens)
+        return (
+            attn_metadata,
+            logits_indices,
+            spec_decode_metadata,
+            num_scheduled_tokens,
+            spec_decode_common_attn_metadata,
+            max_num_scheduled_tokens,
+            token_indices,
+        )
+
+    def _update_token_ids(
+        self,
+        scheduler_output: "SchedulerOutput",
+        token_indices: np.ndarray,
+    ) -> None:
+        """Gather the scheduled token IDs and copy them to the GPU.
+
+        Args:
+            scheduler_output: Scheduler output with the total scheduled tokens.
+            token_indices: Indices of the scheduled tokens in the input batch.
+        """
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(
+            self.input_batch.token_ids_cpu_tensor.flatten(),
+            0,
+            torch.from_numpy(token_indices),
+            out=self.input_ids_cpu[:total_num_scheduled_tokens],
+        )
+        # Copy the tensors to the GPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens],
+            non_blocking=True,
+        )
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1511,7 +1534,21 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        exe_count: int = 0,
+        events: Optional[dict] = None,
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        if self.async_execute_model:
+            assert events is not None, \
+                "Events must be provided for async execution."
+            d2h_copy_event = events["d2h_copy_event"]
+            update_sampled_tokens_event = events["update_sampled_tokens_event"]
+            # While the previous execute_model call is still running, wait
+            # for its D2H copy to start, then run update_states and
+            # prepare_inputs for this step concurrently with that copy.
+            # The first call (exe_count == 0) has nothing to wait for.
+            if exe_count > 0:
+                d2h_copy_event.wait()
+                d2h_copy_event.clear()
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
@@ -1522,9 +1559,15 @@ def execute_model(
                     self.vllm_config)
 
         # Prepare the decoder inputs.
-        (attn_metadata, logits_indices, spec_decode_metadata,
-         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = (self._prepare_inputs(scheduler_output))
+        (
+            attn_metadata,
+            logits_indices,
+            spec_decode_metadata,
+            num_scheduled_tokens_np,
+            spec_decode_common_attn_metadata,
+            max_query_len,
+            token_indices,
+        ) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1558,31 +1601,11 @@ def execute_model(
         mm_embeds = []
 
         if self.supports_mm_inputs and get_pp_group().is_first_rank:
-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            inputs_embeds_scheduled = self.model.get_input_embeddings(
-                input_ids=self.input_ids[:num_scheduled_tokens],
-                multimodal_embeddings=mm_embeds or None,
-            )
-
-            # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(
-                inputs_embeds_scheduled)
-
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
             model_kwargs = {
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
         else:
-            # For text-only models, we use token ids as input.
-            # While it is possible to use embeddings as input just like the
-            # multimodal models, it is not desirable for performance since
-            # then the embedding layer is not included in the CUDA graph.
-            input_ids = self.input_ids[:num_input_tokens]
-            inputs_embeds = None
             model_kwargs = self._init_model_kwargs(num_input_tokens)
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
@@ -1602,6 +1625,36 @@ def execute_model(
         cudagraph_runtime_mode, batch_descriptor = \
             self.cudagraph_dispatcher.dispatch(batch_descriptor)
 
+        if self.async_execute_model and exe_count > 0:
+            # The input token IDs depend on the tokens sampled in the last
+            # execute_model step, so wait for that step's token-id update.
+            update_sampled_tokens_event.wait()
+            update_sampled_tokens_event.clear()
+        # The following operations consume the token IDs (input IDs).
+        if self.supports_mm_inputs and get_pp_group().is_first_rank:
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
+                multimodal_embeddings=mm_embeds or None,
+            )
+
+            # TODO(woosuk): Avoid the copy. Optimize.
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
+            input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids[:num_input_tokens]
+            inputs_embeds = None
+        self._update_token_ids(scheduler_output, token_indices)
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(
@@ -1733,13 +1786,20 @@ def execute_model(
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
+        sample_token_ids_cpu = sampled_token_ids.to(device="cpu",
+                                                    non_blocking=True)
+        if self.async_execute_model:
+            # The D2H copy of the sampled token IDs has started; signal the
+            # other thread so it can run update_states/prepare_inputs now.
+            d2h_copy_event.set()
+        torch.cuda.synchronize()
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+            valid_sampled_token_ids = sample_token_ids_cpu.tolist()
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
+                sample_token_ids_cpu,
                 self.input_batch.vocab_size,
             )
         # Mask out the sampled tokens that should not be sampled.
@@ -1770,6 +1830,10 @@ def execute_model(
             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
+        if self.async_execute_model:
+            # The token IDs in input_batch have been updated, so notify the
+            # waiting execute_model call that it can continue.
+            update_sampled_tokens_event.set()
 
         if self.speculative_config:
             assert spec_decode_common_attn_metadata is not None
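
Note: the following is an illustrative sketch, not part of the patch. It shows the two-slot pipelining pattern that execute_model_with_queue implements: each call is submitted to a small thread pool, and the caller only blocks for the oldest in-flight result once the queue is full, so the CPU-side prepare_inputs of the next step can overlap the GPU execution of the previous one. The step() function and its sleep are hypothetical stand-ins for the worker's execute_model call.

import queue
import time
from concurrent.futures import ThreadPoolExecutor

exe_queue: queue.Queue = queue.Queue(2)
exe_thread_pool = ThreadPoolExecutor(max_workers=2)


def submit_step(func, *args):
    # Mirrors execute_model_with_queue: returns None on the first call,
    # then blocks for (and returns) the oldest in-flight result.
    output = None
    if not exe_queue.full():
        exe_queue.put_nowait(exe_thread_pool.submit(func, *args))
    if exe_queue.full():
        output = exe_queue.get().result()
        exe_queue.task_done()
    return output


def step(i: int) -> str:
    time.sleep(0.1)  # stand-in for model execution on the GPU
    return f"output of step {i}"


if __name__ == "__main__":
    for i in range(4):
        # The first call prints None; afterwards each call returns the
        # result of the step submitted one iteration earlier.
        print(i, submit_step(step, i))
    exe_thread_pool.shutdown(wait=True)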