Skip to content

Commit b82519d

Browse files
ZeldaHuangwangyu31577
authored andcommitted
[Feature] Support Qwen Omni online batch inference (vllm-project#438)
Signed-off-by: ZeldaHuang <[email protected]> Signed-off-by: wangyu31577 <[email protected]>
1 parent e78ef50 commit b82519d

2 files changed

Lines changed: 132 additions & 90 deletions

File tree

tests/e2e/online_serving/test_qwen3_omni.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
E2E Online tests for Qwen3-Omni model with video input and audio output.
55
"""
66

7+
import concurrent.futures
78
import os
89
import socket
910
import subprocess
@@ -167,40 +168,55 @@ def dummy_messages_from_video_data(
167168

168169

169170
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
170-
def test_video_to_audio(
171+
def test_video_to_audio_concurrent(
171172
client: openai.OpenAI,
172173
omni_server,
173174
base64_encoded_video: str,
174175
) -> None:
175-
"""Test processing video, generating audio output via OpenAI API."""
176+
"""Test processing video with multiple concurrent completions, generating audio output via OpenAI API."""
176177
# Create data URL for the base64 encoded video
177178
video_data_url = f"data:video/mp4;base64,{base64_encoded_video}"
178179

179180
messages = dummy_messages_from_video_data(video_data_url)
180181

181-
# Test single completion
182-
chat_completion = client.chat.completions.create(
183-
model=omni_server.model,
184-
messages=messages,
185-
)
186-
187-
assert len(chat_completion.choices) == 2 # 1 for text output, 1 for audio output
188-
189-
# Verify text output
190-
text_choice = chat_completion.choices[0]
191-
assert text_choice.finish_reason == "length"
192-
193-
# Verify we got a response
194-
text_message = text_choice.message
195-
assert text_message.content is not None and len(text_message.content) >= 10
196-
assert text_message.role == "assistant"
197-
198-
# Verify audio output
199-
audio_choice = chat_completion.choices[1]
200-
assert audio_choice.finish_reason == "stop"
201-
audio_message = audio_choice.message
202-
203-
# Check if audio was generated
204-
if hasattr(audio_message, "audio") and audio_message.audio:
205-
assert audio_message.audio.data is not None
206-
assert len(audio_message.audio.data) > 0
182+
# Test multiple concurrent completions
183+
num_concurrent_requests = 5
184+
185+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
186+
# Submit multiple completion requests concurrently
187+
futures = [
188+
executor.submit(
189+
client.chat.completions.create,
190+
model=omni_server.model,
191+
messages=messages,
192+
)
193+
for _ in range(num_concurrent_requests)
194+
]
195+
196+
# Wait for all requests to complete and collect results
197+
chat_completions = [future.result() for future in concurrent.futures.as_completed(futures)]
198+
199+
# Verify all completions succeeded
200+
assert len(chat_completions) == num_concurrent_requests
201+
202+
for chat_completion in chat_completions:
203+
assert len(chat_completion.choices) == 2 # 1 for text output, 1 for audio output
204+
205+
# Verify text output
206+
text_choice = chat_completion.choices[0]
207+
assert text_choice.finish_reason == "length"
208+
209+
# Verify we got a response
210+
text_message = text_choice.message
211+
assert text_message.content is not None and len(text_message.content) >= 10
212+
assert text_message.role == "assistant"
213+
214+
# Verify audio output
215+
audio_choice = chat_completion.choices[1]
216+
assert audio_choice.finish_reason == "stop"
217+
audio_message = audio_choice.message
218+
219+
# Check if audio was generated
220+
if hasattr(audio_message, "audio") and audio_message.audio:
221+
assert audio_message.audio.data is not None
222+
assert len(audio_message.audio.data) > 0

vllm_omni/entrypoints/omni_stage.py

Lines changed: 88 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import multiprocessing as mp
1818
import os
19+
import queue
1920
import sys
2021
import traceback
2122
from typing import Any
@@ -1135,19 +1136,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
11351136
)
11361137
except Exception as e:
11371138
_logging.getLogger(__name__).warning("[Stage-%s] Failed to send stage ready signal: %s", stage_id, e)
1138-
1139+
generation_out_q = asyncio.Queue()
11391140
# Batch processing loop
1140-
while True:
1141-
task = in_q.get()
1142-
_recv_dequeue_ts = _time.time()
1143-
if task is None:
1144-
_logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
1145-
break
1146-
1147-
_rx_bytes_by_rid: dict[Any, int] = {}
1148-
_rx_decode_ms_by_rid: dict[Any, float] = {}
1149-
_in_flight_ms_by_rid: dict[Any, float] = {}
1141+
_rx_bytes_by_rid: dict[Any, int] = {}
1142+
_rx_decode_ms_by_rid: dict[Any, float] = {}
1143+
_in_flight_ms_by_rid: dict[Any, float] = {}
11501144

1145+
async def generation_single_request(task: dict[str, Any]):
1146+
_recv_dequeue_ts = _time.time()
11511147
rid = task["request_id"]
11521148
try:
11531149
sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None
@@ -1157,62 +1153,101 @@ def filter(self, record: _logging.LogRecord) -> bool:
11571153
_in_flight_ms_by_rid[rid] = 0.0
11581154
except Exception:
11591155
_in_flight_ms_by_rid[rid] = 0.0
1160-
ein, _rx_metrics = try_recv_via_connector(
1161-
task=task,
1162-
connectors=connectors,
1163-
stage_id=stage_id,
1164-
)
1165-
if ein is None or _rx_metrics is None:
1166-
raise RuntimeError(
1167-
f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
1168-
"Ensure connectors are configured for all incoming edges."
1169-
)
1170-
_rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
1171-
_rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
1172-
1173-
sampling_params = task["sampling_params"]
1174-
_logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
1175-
print("--------------------------------", flush=True)
1176-
print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
1177-
print("--------------------------------", flush=True)
11781156
try:
1179-
_batch_seq += 1
1157+
ein, _rx_metrics = try_recv_via_connector(
1158+
task=task,
1159+
connectors=connectors,
1160+
stage_id=stage_id,
1161+
)
1162+
if ein is None or _rx_metrics is None:
1163+
raise RuntimeError(
1164+
f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
1165+
"Ensure connectors are configured for all incoming edges."
1166+
)
1167+
_rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
1168+
_rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
1169+
1170+
sampling_params = task["sampling_params"]
1171+
_logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
1172+
print("--------------------------------", flush=True)
1173+
print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
1174+
print("--------------------------------", flush=True)
11801175
_gen_t0 = _time.time()
11811176
if isinstance(ein, list):
11821177
ein = ein[0]
1183-
11841178
async for res in stage_engine.generate(ein, sampling_params, rid):
11851179
gen_output = res
11861180
_gen_t1 = _time.time()
11871181
_gen_ms = (_gen_t1 - _gen_t0) * 1000.0
1182+
await generation_out_q.put((rid, gen_output, _gen_ms))
1183+
except Exception as e:
1184+
_logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
1185+
out_q.put(
1186+
{
1187+
"request_id": rid,
1188+
"stage_id": stage_id,
1189+
"error": str(e),
1190+
}
1191+
)
11881192

1189-
r_outputs = [gen_output]
1190-
_num_tokens = count_tokens_from_outputs(r_outputs)
1191-
_agg_total_tokens += _num_tokens
1192-
_agg_total_gen_time_ms += _gen_ms
1193-
1194-
if _stats_file:
1195-
_avg_tokens_per_s = (
1196-
(_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
1197-
)
1198-
log_stage_running_avg(
1199-
_stats_file,
1200-
stage_id,
1201-
int(_agg_total_tokens),
1202-
float(_agg_total_gen_time_ms),
1203-
float(_avg_tokens_per_s),
1204-
)
1193+
_batch_gen_t0 = _time.time()
1194+
while True:
1195+
try:
1196+
task = in_q.get_nowait()
1197+
if task is None:
1198+
_logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
1199+
break
1200+
asyncio.create_task(generation_single_request(task))
1201+
except queue.Empty:
1202+
await asyncio.sleep(0.001)
1203+
batch_request_outputs: list[Any] = []
1204+
batch_request_ids: list[Any] = []
1205+
_gen_ms_list = []
1206+
while True:
1207+
try:
1208+
rids, gen_output, _gen_ms = generation_out_q.get_nowait()
1209+
_num_tokens = count_tokens_from_outputs([gen_output])
1210+
batch_request_outputs.append(gen_output)
1211+
_gen_ms_list.append(_gen_ms)
1212+
batch_request_ids.append(rids)
1213+
_agg_total_tokens += _num_tokens
1214+
except asyncio.QueueEmpty:
1215+
await asyncio.sleep(0.001)
1216+
break
1217+
1218+
if not batch_request_outputs:
1219+
continue
1220+
_batch_seq += 1
1221+
if _stats_file:
1222+
_batch_gen_t1 = _time.time()
1223+
_agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0) * 1000
1224+
_batch_gen_t0 = _batch_gen_t1
1225+
_avg_tokens_per_s = (
1226+
(_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
1227+
)
1228+
log_stage_running_avg(
1229+
_stats_file,
1230+
stage_id,
1231+
int(_agg_total_tokens),
1232+
float(_agg_total_gen_time_ms),
1233+
float(_avg_tokens_per_s),
1234+
)
1235+
logger.info("[Stage-%s] Running avg: %s tokens/s", stage_id, _avg_tokens_per_s)
1236+
for rid, _gen_ms in zip(batch_request_ids, _gen_ms_list):
12051237
log_stage_batch_stats(_stats_file, stage_id, 1, float(_gen_ms), [rid])
12061238

1239+
logger.info("[Stage-%s] Sending outputs to main process", stage_id)
1240+
for rid, output, _gen_ms in zip(batch_request_ids, batch_request_outputs, _gen_ms_list):
12071241
try:
1242+
r_outputs = [output]
12081243
use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes)
12091244
_metrics = {
12101245
"num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
12111246
"stage_gen_time_ms": _gen_ms,
12121247
"batch_id": int(_batch_seq),
1213-
"rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
1214-
"rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
1215-
"rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
1248+
"rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
1249+
"rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
1250+
"rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
12161251
}
12171252
if _stats_file:
12181253
compute_and_log_stage_request_stats(
@@ -1266,23 +1301,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
12661301
"metrics": {
12671302
"num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
12681303
"stage_gen_time_ms": _gen_ms,
1269-
"rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
1270-
"rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
1271-
"rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
1304+
"rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
1305+
"rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
1306+
"rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
12721307
},
12731308
}
12741309
)
12751310
_logging.getLogger(__name__).debug("[Stage-%s] Enqueued result for request %s to downstream", stage_id, rid)
12761311

1277-
except Exception as e:
1278-
_logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
1279-
out_q.put(
1280-
{
1281-
"request_id": rid,
1282-
"stage_id": stage_id,
1283-
"error": str(e),
1284-
}
1285-
)
12861312
print("--------------------------------", flush=True)
12871313
print(f"[Stage-{stage_id}] Stage worker exiting", flush=True)
12881314
print("--------------------------------", flush=True)

0 commit comments

Comments
 (0)