Skip to content

Commit 2894e47

Browse files
committed
fix
Signed-off-by: ZeldaHuang <[email protected]>
1 parent 81e0920 commit 2894e47

1 file changed

Lines changed: 14 additions & 46 deletions

File tree

vllm_omni/entrypoints/async_omni.py

Lines changed: 14 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -529,14 +529,6 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
529529
# Start output handler on the first call to generate()
530530
self._run_output_handler()
531531

532-
if sampling_params_list is None:
533-
sampling_params_list = self.default_sampling_params_list
534-
if len(sampling_params_list) != len(self.stage_list):
535-
raise ValueError(
536-
f"Expected {len(self.stage_list)} sampling params, \
537-
got {len(sampling_params_list)}"
538-
)
539-
540532
prompt = args[0] if args else kwargs.get("prompt")
541533
request_id = args[1] if len(args) > 1 else kwargs.get("request_id")
542534
sampling_params_list = args[2] if len(args) > 2 else kwargs.get("sampling_params_list")
@@ -597,6 +589,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
597589

598590
sp0: SamplingParams = sampling_params_list[0] # type: ignore[index]
599591
task = {
592+
"type": OmniStageTaskType.GENERATE,
600593
"request_id": request_id,
601594
"engine_inputs": prompt,
602595
"sampling_params": sp0,
@@ -639,10 +632,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
639632
stage_id,
640633
req_id,
641634
)
642-
# Seed stage-0 queue with all requests
643-
logger.debug("[Orchestrator] Seeding request into stage-0")
644-
req_state = ClientRequestState(request_id)
645-
self.request_states[request_id] = req_state
635+
stage.set_engine_outputs(engine_outputs)
646636

647637
if getattr(stage, "final_output", False):
648638
logger.debug(
@@ -651,45 +641,22 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
651641
stage_id,
652642
)
653643

654-
sp0: SamplingParams = sampling_params_list[0] # type: ignore[index]
655-
task = {
656-
"type": OmniStageTaskType.GENERATE,
657-
"request_id": request_id,
658-
"engine_inputs": prompt,
659-
"sampling_params": sp0,
660-
}
661-
self.stage_list[0].submit(task)
662-
_req_start_ts[request_id] = time.time()
663-
logger.debug("[Orchestrator] Enqueued request %s to stage-0", request_id)
664-
665-
logger.debug("[Orchestrator] Entering scheduling loop: stages=%d", num_stages)
666-
for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
667-
result = await req_state.queue.get()
668-
assert stage_id == req_state.stage_id
669-
670-
req_id = result.get("request_id")
671-
if "error" in result:
672-
logger.error(
673-
"Stage %s error on request %s: %s",
674-
stage_id,
675-
req_id,
676-
result["error"],
677-
)
678-
raise RuntimeError(result) # Request Finished due to error
679-
680-
engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm")
681-
# Mark last output time for this stage whenever we receive outputs
682-
metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time())
644+
# End-to-end timing and time-per-token for final output
645+
# (only once per request at the designated final stage)
683646
try:
684-
_m = asdict(result.get("metrics"))
685-
if _m is not None:
686-
metrics.on_stage_metrics(stage_id, req_id, _m)
647+
rid_key = str(req_id)
648+
if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done:
649+
metrics.on_finalize_request(
650+
stage_id,
651+
req_id,
652+
engine_outputs,
653+
_req_start_ts.get(req_id, _wall_start_ts),
654+
)
687655
except Exception as e:
688656
logger.exception(
689657
"[AsyncOrchestrator] Finalize request handling error for req %s at stage %s: %s",
690658
req_id,
691659
stage_id,
692-
req_id,
693660
e,
694661
)
695662

@@ -754,6 +721,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
754721
logger.debug(
755722
"[AsyncOrchestrator] Forwarded request %s to stage-%s",
756723
req_id,
724+
next_stage_id,
757725
)
758726
else:
759727
logger.debug("[AsyncOrchestrator] Request %s fully completed", req_id)
@@ -770,7 +738,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
770738
self.request_states.pop(request_id, None)
771739
except (asyncio.CancelledError, GeneratorExit):
772740
await self.abort(request_id)
773-
print("Request %s aborted.", request_id)
741+
logger.exception("[AsyncOrchestrator] Request %s aborted.", request_id)
774742
raise
775743

776744
def _wait_for_stages_ready(self, timeout: int = 120) -> None:

0 commit comments

Comments
 (0)