Skip to content

Commit 344a91e

Browse files
committed
send response with request id
Signed-off-by: dengyunyang <584797741@qq.com>
1 parent 7aaec2d commit 344a91e

2 files changed

Lines changed: 169 additions & 113 deletions

File tree

vllm_omni/entrypoints/async_omni.py

Lines changed: 159 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs
4747
from vllm_omni.engine.output_processor import MultimodalOutputProcessor
4848
from vllm_omni.engine.processor import OmniProcessor
49+
from vllm_omni.entrypoints.client_request_state import ClientRequestState
4950
from vllm_omni.entrypoints.log_utils import (
5051
OrchestratorMetrics,
5152
configure_orchestrator_logger,
@@ -122,6 +123,8 @@ def __init__(
122123
self.stage_configs = load_stage_configs_from_yaml(cli_args.stage_configs_path, base_engine_args)
123124

124125
shm_threshold_bytes = cli_args.shm_threshold_bytes
126+
self.output_handler: Optional[asyncio.Task] = None
127+
self.request_states: dict[str, ClientRequestState] = {} # request_id -> state
125128

126129
# Initialize connectors
127130
self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors(
@@ -238,6 +241,10 @@ def close(self) -> None:
238241
except Exception as e:
239242
logger.warning("[Orchestrator] Failed to stop stage worker: %s", e)
240243

244+
if self.output_handler is not None:
245+
self.output_handler.cancel()
246+
self.output_handler = None
247+
241248
try_close_ray(self._ray_pg)
242249

243250
def __del__(self) -> None: # best-effort
@@ -295,6 +302,10 @@ async def generate(
295302
ValueError: If sampling_params_list has incorrect length.
296303
"""
297304
logger.debug("[Orchestrator] generate() called")
305+
306+
# Start output handler on the first call to generate()
307+
self._run_output_handler()
308+
298309
if sampling_params_list is None:
299310
sampling_params_list = self.default_sampling_params_list
300311
if len(sampling_params_list) != len(self.stage_list):
@@ -339,6 +350,9 @@ async def generate(
339350
)
340351
# Seed stage-0 queue with all requests
341352
logger.debug("[Orchestrator] Seeding request into stage-0")
353+
req_state = ClientRequestState(request_id)
354+
self.request_states[request_id] = req_state
355+
342356
# Mark first input time for stage-0
343357
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
344358

@@ -353,135 +367,120 @@ async def generate(
353367
logger.debug("[Orchestrator] Enqueued request %s to stage-0", request_id)
354368

355369
logger.debug("[Orchestrator] Entering scheduling loop: stages=%d", num_stages)
356-
finished = False
357-
while not finished:
358-
made_progress = False
359-
for stage_id, stage in enumerate(self.stage_list):
360-
result = stage.try_collect()
361-
if result is None:
362-
continue
370+
for stage_id, stage in enumerate(self.stage_list):
371+
result = await req_state.queue.get()
372+
assert stage_id == req_state.stage_id
363373

364-
made_progress = True
365-
req_id = result.get("request_id")
366-
if "error" in result:
367-
logger.error(
368-
"Stage %s error on request %s: %s",
369-
stage_id,
370-
req_id,
371-
result["error"],
372-
)
373-
continue
374+
req_id = result.get("request_id")
375+
if "error" in result:
376+
logger.error(
377+
"Stage %s error on request %s: %s",
378+
stage_id,
379+
req_id,
380+
result["error"],
381+
)
382+
raise RuntimeError(result) # Request Finished due to error
374383

375-
if result.get("type") == "stage_ready":
376-
# Only happens when stage is initialized slower than expected,
377-
# so we wait for a short time and try again
378-
time.sleep(0.05)
379-
continue
384+
engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm")
385+
# Mark last output time for this stage whenever we receive outputs
386+
metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time())
387+
try:
388+
_m = result.get("metrics")
389+
if _m is not None:
390+
metrics.on_stage_metrics(stage_id, req_id, _m)
391+
except Exception as e:
392+
logger.exception(
393+
"[Orchestrator] Failed to process metrics for stage %s, \
394+
req %s: %s",
395+
stage_id,
396+
req_id,
397+
e,
398+
)
399+
logger.debug(
400+
"[Orchestrator] Stage-%s completed request %s; \
401+
forwarding or finalizing",
402+
stage_id,
403+
req_id,
404+
)
405+
stage.set_engine_outputs(engine_outputs)
380406

381-
engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm")
382-
# Mark last output time for this stage whenever we receive outputs
383-
metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time())
407+
if getattr(stage, "final_output", False):
408+
logger.debug(
409+
"[Orchestrator] Request %s finalized at stage-%s",
410+
req_id,
411+
stage_id,
412+
)
413+
414+
# End-to-end timing and time-per-token for final output
415+
# (only once per request at the designated final stage)
384416
try:
385-
_m = result.get("metrics")
386-
if _m is not None:
387-
metrics.on_stage_metrics(stage_id, req_id, _m)
417+
rid_key = str(req_id)
418+
if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done:
419+
metrics.on_finalize_request(
420+
stage_id,
421+
req_id,
422+
engine_outputs,
423+
_req_start_ts.get(req_id, _wall_start_ts),
424+
)
388425
except Exception as e:
389426
logger.exception(
390-
"[Orchestrator] Failed to process metrics for stage %s, \
391-
req %s: %s",
392-
stage_id,
427+
"[Orchestrator] Finalize request handling error for \
428+
req %s at stage %s: %s",
393429
req_id,
430+
stage_id,
394431
e,
395432
)
396-
logger.debug(
397-
"[Orchestrator] Stage-%s completed request %s; \
398-
forwarding or finalizing",
399-
stage_id,
400-
req_id,
433+
434+
if isinstance(engine_outputs, list):
435+
engine_outputs = engine_outputs[0]
436+
yield OmniRequestOutput(
437+
stage_id=stage_id,
438+
final_output_type=stage.final_output_type,
439+
request_output=engine_outputs,
401440
)
402-
stage.set_engine_outputs(engine_outputs)
403441

404-
if getattr(stage, "final_output", False):
405-
logger.debug(
406-
"[Orchestrator] Request %s finalized at stage-%s",
407-
req_id,
408-
stage_id,
409-
)
442+
next_stage_id = stage_id + 1
443+
if next_stage_id < num_stages:
444+
next_stage: OmniStage = self.stage_list[next_stage_id]
445+
next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt)
446+
sp_next: SamplingParams = sampling_params_list[next_stage_id]
410447

411-
# End-to-end timing and time-per-token for final output
412-
# (only once per request at the designated final stage)
413-
try:
414-
rid_key = str(req_id)
415-
if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done:
416-
metrics.on_finalize_request(
417-
stage_id,
418-
req_id,
419-
engine_outputs,
420-
_req_start_ts.get(req_id, _wall_start_ts),
421-
)
422-
except Exception as e:
423-
logger.exception(
424-
"[Orchestrator] Finalize request handling error for \
425-
req %s at stage %s: %s",
426-
req_id,
427-
stage_id,
428-
e,
429-
)
448+
# Check if we have a connector for this edge
449+
connector_key = (str(stage_id), str(next_stage_id))
450+
connector = self.connectors.get(connector_key)
430451

431-
if isinstance(engine_outputs, list):
432-
engine_outputs = engine_outputs[0]
433-
yield OmniRequestOutput(
452+
sent_via_connector = False
453+
if connector:
454+
sent_via_connector = try_send_via_connector(
455+
connector=connector,
434456
stage_id=stage_id,
435-
final_output_type=stage.final_output_type,
436-
request_output=engine_outputs,
457+
next_stage_id=next_stage_id,
458+
req_id=req_id,
459+
next_inputs=next_inputs,
460+
sampling_params=sp_next,
461+
original_prompt=prompt,
462+
next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit,
463+
metrics=metrics,
437464
)
438465

439-
next_stage_id = stage_id + 1
440-
if next_stage_id < num_stages:
441-
next_stage: OmniStage = self.stage_list[next_stage_id]
442-
next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt)
443-
sp_next: SamplingParams = sampling_params_list[next_stage_id]
444-
445-
# Check if we have a connector for this edge
446-
connector_key = (str(stage_id), str(next_stage_id))
447-
connector = self.connectors.get(connector_key)
448-
449-
sent_via_connector = False
450-
if connector:
451-
sent_via_connector = try_send_via_connector(
452-
connector=connector,
453-
stage_id=stage_id,
454-
next_stage_id=next_stage_id,
455-
req_id=req_id,
456-
next_inputs=next_inputs,
457-
sampling_params=sp_next,
458-
original_prompt=prompt,
459-
next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit,
460-
metrics=metrics,
461-
)
462-
463-
if not sent_via_connector:
464-
# Fallback logic removed as we now enforce connector usage.
465-
# If no connector is found or send fails, we log an error and raise,
466-
# because continuing would cause the request to be silently dropped
467-
# and the orchestrator to hang waiting for completion.
468-
error_msg = (
469-
f"[Orchestrator] Failed to send request {req_id} to stage-{next_stage_id} via connector. "
470-
"Configure a connector for this edge or inspect connector logs for details."
471-
)
472-
logger.error(error_msg)
473-
raise RuntimeError(error_msg)
474-
logger.debug(
475-
"[Orchestrator] Forwarded request %s to stage-%s",
476-
req_id,
477-
next_stage_id,
466+
if not sent_via_connector:
467+
# Fallback logic removed as we now enforce connector usage.
468+
# If no connector is found or send fails, we log an error and raise,
469+
# because continuing would cause the request to be silently dropped
470+
# and the orchestrator to hang waiting for completion.
471+
error_msg = (
472+
f"[Orchestrator] Failed to send request {req_id} to stage-{next_stage_id} via connector. "
473+
"Configure a connector for this edge or inspect connector logs for details."
478474
)
479-
else:
480-
finished = True
481-
logger.debug("[Orchestrator] Request %s fully completed", req_id)
482-
483-
if not made_progress:
484-
time.sleep(0.005)
475+
logger.error(error_msg)
476+
raise RuntimeError(error_msg)
477+
logger.debug(
478+
"[Orchestrator] Forwarded request %s to stage-%s",
479+
req_id,
480+
next_stage_id,
481+
)
482+
else:
483+
logger.debug("[Orchestrator] Request %s fully completed", req_id)
485484

486485
logger.debug("[Orchestrator] All requests completed")
487486

@@ -491,6 +490,8 @@ async def generate(
491490
logger.info("[Summary] %s", summary)
492491
except Exception as e:
493492
logger.exception("[Orchestrator] Failed to build/log summary: %s", e)
493+
finally:
494+
self.request_states.pop(request_id, None)
494495

495496
def _wait_for_stages_ready(self, timeout: int = 120) -> None:
496497
num_stages = len(self.stage_list)
@@ -560,6 +561,51 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None:
560561
occurred while logging suggestions",
561562
)
562563

564+
def _run_output_handler(self) -> None:
    """Start the background task that routes stage outputs to request queues.

    The call is a no-op when ``self.output_handler`` is already set. It
    uses ``asyncio.create_task`` and therefore must run inside an active
    event loop.

    The task repeatedly polls every stage in ``self.stage_list`` with
    ``try_collect()``. Each collected result is put on the queue of the
    request state registered in ``self.request_states`` under the result's
    ``"request_id"``, and that state's ``stage_id`` is set to the index of
    the stage that produced it. Results whose request id is not registered
    are dropped with a debug log.

    If the polling loop raises, every registered request receives an
    ``{"request_id": ..., "error": ...}`` item carrying its own request id,
    and ``self.output_handler`` is reset to ``None`` so that a later call
    can start a fresh task.
    """
    if self.output_handler is not None:
        return

    # Bind to locals so the task closure does not go through ``self`` on
    # every poll iteration.
    stage_list = self.stage_list
    request_states = self.request_states

    async def output_handler():
        try:
            while True:
                idle = True
                for stage_id, stage in enumerate(stage_list):
                    result = stage.try_collect()
                    if result is None:
                        continue
                    idle = False
                    if result.get("type") == "stage_ready":
                        # Only happens when stage is initialized slower than expected,
                        # so we wait for a short time and try again
                        await asyncio.sleep(0.05)
                        continue
                    req_id = result.get("request_id")
                    req_state = request_states.get(req_id)
                    if req_state is None:
                        # Implicit concatenation instead of a backslash
                        # continuation, so source indentation never leaks
                        # into the log message.
                        logger.debug(
                            "[Orchestrator] Request may have been aborted; "
                            "dropping output for req %s at stage-%s ",
                            req_id,
                            stage_id,
                        )
                        continue
                    await req_state.queue.put(result)
                    req_state.stage_id = stage_id
                if idle:
                    await asyncio.sleep(0.001)  # Avoid CPU overload when idle
                else:
                    # Yield to the event loop so consumers can drain their queues.
                    await asyncio.sleep(0)
        except Exception as e:
            logger.exception("AsyncOmni output_handler failed.")
            # Iterate over a snapshot: entries may be removed from the dict
            # while this coroutine is suspended in ``put``.
            for pending_state in list(request_states.values()):
                # Use each request's own id. The loop-local ``req_id`` may be
                # unbound here (failure before any result was routed) or may
                # belong to an unrelated request.
                await pending_state.queue.put(
                    {"request_id": pending_state.request_id, "error": str(e)}
                )
            self.output_handler = None  # Allow a restart on the next call

    self.output_handler = asyncio.create_task(output_handler())
608+
563609
@property
564610
def is_running(self) -> bool:
565611
# Is None before the loop is started.
@@ -782,7 +828,7 @@ def __init__(
782828
)
783829
self.logger_manager.log_engine_initialized()
784830

785-
self.output_handler: Optional[asyncio.Task] = None
831+
self.output_handler: asyncio.Task | None = None
786832
try:
787833
# Start output handler eagerly if we are in the asyncio eventloop.
788834
asyncio.get_running_loop()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import asyncio
2+
3+
4+
class ClientRequestState:
5+
"""Tracks the state of an individual request in the orchestrator."""
6+
7+
def __init__(self, request_id: str, queue: asyncio.Queue | None = None):
8+
self.request_id = request_id
9+
self.stage_id: int | None = None
10+
self.queue = queue if queue is not None else asyncio.Queue()

0 commit comments

Comments
 (0)