Merged

Commits (30 total; the diff below shows changes from 5 commits)
95b1b7d  first prototype, working for BS=1 (benchislett, Aug 24, 2025)
5bd9851  wip for batched (benchislett, Aug 25, 2025)
9a59696  fix bs > 1 (benchislett, Aug 25, 2025)
a24d715  add back removed code (benchislett, Aug 25, 2025)
51b6169  minor perf optimization (benchislett, Aug 25, 2025)
2832e37  improvements (benchislett, Aug 26, 2025)
bd331b4  remove old prints (benchislett, Aug 26, 2025)
dfa5ca9  Merge branch 'main' into overlap-model-execution (benchislett, Aug 26, 2025)
c118525  fix precommit (benchislett, Aug 26, 2025)
43b4f17  Merge branch 'main' into overlap-model-execution (benchislett, Aug 27, 2025)
752ccf9  misc cleanup (benchislett, Aug 27, 2025)
9f28326  refactor prepare_input_ids (benchislett, Aug 27, 2025)
15d7b31  tiny refactor to reorder some ops (benchislett, Aug 27, 2025)
b351a56  Merge branch 'main' into overlap-model-execution (benchislett, Sep 2, 2025)
5df3ae8  refactor async model runner output (benchislett, Sep 2, 2025)
efcc3ee  tiny cleanup (benchislett, Sep 2, 2025)
b4611f4  Merge branch 'main' into overlap-model-execution (benchislett, Sep 2, 2025)
6c025bb  remove torch from multiproc_executor (benchislett, Sep 2, 2025)
bc99a79  refactor async output in multiproc executor (benchislett, Sep 3, 2025)
2ffa123  cleanup (benchislett, Sep 3, 2025)
7ae3166  improve async gpu model runner output structure (benchislett, Sep 3, 2025)
75c109d  use cuda event to sync copy stream (benchislett, Sep 3, 2025)
3f9d46b  Merge branch 'main' into overlap-model-execution (benchislett, Sep 3, 2025)
ff5bc7a  minor refactor for readability (benchislett, Sep 4, 2025)
6a44032  more minor refactor (benchislett, Sep 4, 2025)
b411981  Merge branch 'main' into overlap-model-execution (benchislett, Sep 4, 2025)
0d23f0e  refactor prepare_input_ids for fewer cpu ops (benchislett, Sep 4, 2025)
54feea9  restructure multiproc output handling to isolate effects on non-async… (benchislett, Sep 5, 2025)
4bddae2  Merge branch 'main' into overlap-model-execution (benchislett, Sep 5, 2025)
70f4921  Merge branch 'main' into overlap-model-execution (benchislett, Sep 5, 2025)
51 changes: 49 additions & 2 deletions vllm/v1/executor/multiproc_executor.py
@@ -3,6 +3,7 @@
import multiprocessing
import os
import pickle
import queue
import signal
import threading
import time
@@ -18,6 +19,7 @@
from typing import Any, Callable, Optional, Union, cast

import cloudpickle
import torch

import vllm.envs as envs
from vllm.config import VllmConfig
@@ -586,6 +588,49 @@ class ResponseStatus(Enum):

def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""

def process_output(output: Any, worker_response_mq: MessageQueue,
copy_stream: torch.cuda.Stream):
if isinstance(output, ModelRunnerOutput) and isinstance(
output.sampled_token_ids, tuple):
# sampled_token_ids is a tuple of (Tensor, list[int])
# where the second element is the list of invalid req indices
tensor, invalid_req_indices = output.sampled_token_ids
tensor = cast(torch.Tensor, tensor)
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
copy_stream.wait_stream(default_stream)
sampled_token_ids_list = tensor.to('cpu',
non_blocking=True)
copy_stream.synchronize()
sampled_token_ids_list = sampled_token_ids_list.tolist()
for i in invalid_req_indices:
sampled_token_ids_list[i].clear()
output.sampled_token_ids = sampled_token_ids_list
worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
else:
worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
return

def _output_processor_loop(input_queue: queue.Queue,
worker_response_mq: MessageQueue,
copy_stream: torch.cuda.Stream):
while True:
output = input_queue.get()
process_output(output, worker_response_mq, copy_stream)

output_queue: queue.Queue = queue.Queue()
copy_stream_ = torch.cuda.Stream()
output_processor_thread = Thread(target=_output_processor_loop,
args=(output_queue,
self.worker_response_mq,
copy_stream_),
daemon=True,
name="WorkerOutputProcessor")
output_processor_thread.start()

while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

@@ -603,10 +648,12 @@ def worker_busy_loop(self):
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
if output_rank is None or self.rank == output_rank:
print("Enqueueing FAILURE message to worker_response_mq")
print(traceback.format_exc())
print(str(e))
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue

if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
output_queue.put(output)
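For readers following the new worker-side flow: the busy loop now only enqueues raw outputs, and a daemon thread performs the GPU-to-CPU copy of the sampled token IDs on a dedicated CUDA stream before publishing the result. Below is a minimal, self-contained sketch of that copy-stream pattern; the `publish` callback, queue layout, and function names are illustrative placeholders rather than vLLM's actual `MessageQueue`/`WorkerProc` API, and a CUDA device is assumed.

```python
import queue
import threading
from typing import Callable

import torch


def output_processor_loop(
    work_queue: "queue.Queue[tuple[torch.Tensor, list[int]]]",
    publish: Callable[[list[list[int]]], None],
    copy_stream: torch.cuda.Stream,
) -> None:
    """Drain GPU outputs and copy them to the host off the main stream."""
    while True:
        sampled_gpu, invalid_req_indices = work_queue.get()
        # Order the copy after work already queued on the default stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            host_tensor = sampled_gpu.to("cpu", non_blocking=True)
        # Only this helper thread blocks here; the busy loop keeps running.
        copy_stream.synchronize()
        token_ids = host_tensor.tolist()
        for i in invalid_req_indices:
            token_ids[i].clear()  # drop samples for requests marked invalid
        publish(token_ids)


if __name__ == "__main__":
    results: "queue.Queue[list[list[int]]]" = queue.Queue()
    work: "queue.Queue[tuple[torch.Tensor, list[int]]]" = queue.Queue()
    threading.Thread(
        target=output_processor_loop,
        args=(work, results.put, torch.cuda.Stream()),
        daemon=True,
        name="WorkerOutputProcessor",
    ).start()

    sampled = torch.randint(0, 32_000, (4, 1), device="cuda")
    work.put((sampled, [2]))  # pretend request index 2 must be discarded
    print(results.get())      # e.g. [[17], [905], [], [3031]]
```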
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -249,6 +249,11 @@ def __init__(

self.pooling_params: dict[str, PoolingParams] = {}

# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_sampled_token_ids_invalid_indices: set[int] | None = None
self.prev_req_id_to_index: dict[str, int] | None = None

@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
132 changes: 109 additions & 23 deletions vllm/v1/worker/gpu_model_runner.py
@@ -229,6 +229,8 @@ def __init__(
is_pooling_model=self.is_pooling_model,
)

self.use_async_scheduling = self.scheduler_config.async_scheduling

# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
@@ -777,6 +779,57 @@ def _prepare_inputs(
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

if self.input_batch.prev_sampled_token_ids is not None:
# First, find which requests in the current batch were also present in
# the previous (cached) batch, and what their indices are in the new batch.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
current_req_id_to_index = self.input_batch.req_id_to_index
common_req_ids = set(prev_req_id_to_index.keys()).intersection(
set(current_req_id_to_index.keys()))
if common_req_ids:
current_common_req_indices = [
current_req_id_to_index[req_id]
for req_id in common_req_ids
]
prev_common_req_indices = [
prev_req_id_to_index[req_id] for req_id in common_req_ids
]
# Now, for each of these requests, we need to copy the last sampled token from the previous batch
# to the correct position in the current input_ids tensor

# We need to compute the flattened input_ids index of the last token for each of these requests
flattened_indices = [
int(cu_num_tokens[idx]) - 1
for idx in current_common_req_indices
]
if flattened_indices == prev_common_req_indices and \
set(flattened_indices) == set(range(len(flattened_indices))):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of [0, 1, 2, ..., len - 1]
self.input_ids[:len(flattened_indices)].copy_(
self.input_batch.prev_sampled_token_ids[:len(
flattened_indices)].squeeze(1),
non_blocking=True)
else:
# Upload the index tensors asynchronously so the scatter can be non-blocking
input_ids_index_tensor = torch.tensor(
flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device,
non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device,
non_blocking=True)
self.input_ids.scatter_(
dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor].squeeze(1))

if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
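The block above back-fills the last input token of every request that survived from the previous step, using the GPU-resident `prev_sampled_token_ids` instead of waiting for a host copy. A minimal standalone sketch of that logic follows, assuming one decode token per surviving request and a CUDA device with pinned-memory staging; the function and tensor names here are hypothetical, not the actual `GPUModelRunner` members.

```python
import torch


def backfill_last_tokens(
    input_ids_gpu: torch.Tensor,           # flattened [num_tokens] buffer on GPU
    prev_sampled: torch.Tensor,            # [num_prev_reqs, 1], still on GPU
    prev_req_id_to_index: dict[str, int],
    cur_req_id_to_index: dict[str, int],
    cu_num_tokens: "list[int] | torch.Tensor",  # cumulative tokens per request (CPU)
) -> None:
    """Write each surviving request's previously sampled token into input_ids."""
    common = prev_req_id_to_index.keys() & cur_req_id_to_index.keys()
    if not common:
        return
    cur_indices = [cur_req_id_to_index[r] for r in common]
    prev_indices = [prev_req_id_to_index[r] for r in common]
    # Flattened position of each request's last scheduled token.
    flat = [int(cu_num_tokens[i]) - 1 for i in cur_indices]

    if flat == prev_indices and set(flat) == set(range(len(flat))):
        # Fast path: unchanged, unreordered decode-only batch -> contiguous copy.
        input_ids_gpu[:len(flat)].copy_(
            prev_sampled[:len(flat)].squeeze(1), non_blocking=True)
        return

    # General path: ship both index lists to the GPU asynchronously, then
    # scatter the cached samples into their new flattened positions.
    device = input_ids_gpu.device
    dst = torch.tensor(flat, dtype=torch.int64,
                       pin_memory=True).to(device, non_blocking=True)
    src_rows = torch.tensor(prev_indices, dtype=torch.int64,
                            pin_memory=True).to(device, non_blocking=True)
    input_ids_gpu.scatter_(dim=0, index=dst,
                           src=prev_sampled[src_rows].squeeze(1))
```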
@@ -1727,45 +1780,72 @@ def execute_model(
scheduler_output.num_scheduled_tokens,
)

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
if not self.use_async_scheduling:
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
valid_sampled_token_ids = None
sampled_token_ids_tensor = sampler_output.sampled_token_ids
invalid_req_indices = list(discard_sampled_tokens_req_indices)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids_tensor.shape[-1] == 1

# Cache the sampled tokens on the GPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids_tensor.clone()
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
for req_idx in range(num_sampled_tokens):
if req_idx in discard_sampled_tokens_req_indices:
continue

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
end_idx = start_idx + 1
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx

if self.use_async_scheduling:
sampled_ids = [-1] * 1
else:
sampled_ids = valid_sampled_token_ids[req_idx]

req_id = req_ids[req_idx]
req_state = self.requests[req_id]
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
req_state.output_token_ids.extend(sampled_ids)

if self.speculative_config:
@@ -1783,17 +1863,23 @@

self.eplb_step()

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids.copy(),
req_id_to_index=self.input_batch.req_id_to_index.copy(),
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
num_nans_in_logits=num_nans_in_logits.copy(),
)

if self.use_async_scheduling:
output.sampled_token_ids = (sampled_token_ids_tensor,
invalid_req_indices)

return output

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
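To summarize the async-scheduling hand-off added in `execute_model`: the runner writes a placeholder token into the host-side buffer, caches the GPU tensor and the request-index mapping for the next step's input preparation, and returns `(tensor, invalid_req_indices)` for the executor-side thread to materialize later. The sketch below illustrates that hand-off with simplified, hypothetical types (`StepOutput`, `finish_step_async`); it is not vLLM's `ModelRunnerOutput`.

```python
from dataclasses import dataclass
from typing import Union

import torch

PLACEHOLDER_TOKEN_ID = -1  # stands in for a token that is still GPU-resident


@dataclass
class StepOutput:
    req_ids: list[str]
    # Sync path: concrete per-request token lists.
    # Async path: the GPU tensor plus the request indices to discard.
    sampled_token_ids: Union[list[list[int]], tuple[torch.Tensor, list[int]]]


def finish_step_async(
    req_ids: list[str],
    sampled_gpu: torch.Tensor,        # shape [num_reqs, 1], stays on the GPU
    discard_indices: set[int],
    token_ids_cpu: torch.Tensor,      # [max_reqs, max_len] host-side buffer
    num_tokens: list[int],            # tokens recorded so far per request
) -> tuple[StepOutput, dict[str, int]]:
    """Record placeholders on the host and hand the GPU samples downstream."""
    for req_idx in range(sampled_gpu.shape[0]):
        if req_idx in discard_indices:
            continue
        # Reserve the slot now; the real token id is filled in after the
        # executor-side thread copies the tensor to the host.
        token_ids_cpu[req_idx, num_tokens[req_idx]] = PLACEHOLDER_TOKEN_ID
        num_tokens[req_idx] += 1

    # Cache the mapping so the next step's input preparation can find each
    # request's previously sampled token by index.
    prev_req_id_to_index = {
        req_id: i for i, req_id in enumerate(req_ids) if i not in discard_indices
    }
    output = StepOutput(
        req_ids=list(req_ids),
        sampled_token_ids=(sampled_gpu.clone(), sorted(discard_indices)),
    )
    return output, prev_req_id_to_index
```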