4646from vllm_omni .engine .arg_utils import AsyncOmniEngineArgs
4747from vllm_omni .engine .output_processor import MultimodalOutputProcessor
4848from vllm_omni .engine .processor import OmniProcessor
49+ from vllm_omni .entrypoints .client_request_state import ClientRequestState
4950from vllm_omni .entrypoints .log_utils import (
5051 OrchestratorMetrics ,
5152 configure_orchestrator_logger ,
@@ -122,6 +123,8 @@ def __init__(
122123 self .stage_configs = load_stage_configs_from_yaml (cli_args .stage_configs_path , base_engine_args )
123124
124125 shm_threshold_bytes = cli_args .shm_threshold_bytes
126+ self .output_handler : Optional [asyncio .Task ] = None
127+ self .request_states : dict [str , ClientRequestState ] = {} # request_id -> state
125128
126129 # Initialize connectors
127130 self .omni_transfer_config , self .connectors = initialize_orchestrator_connectors (
@@ -238,6 +241,10 @@ def close(self) -> None:
238241 except Exception as e :
239242 logger .warning ("[Orchestrator] Failed to stop stage worker: %s" , e )
240243
244+ if self .output_handler is not None :
245+ self .output_handler .cancel ()
246+ self .output_handler = None
247+
241248 try_close_ray (self ._ray_pg )
242249
243250 def __del__ (self ) -> None : # best-effort
@@ -295,6 +302,10 @@ async def generate(
295302 ValueError: If sampling_params_list has incorrect length.
296303 """
297304 logger .debug ("[Orchestrator] generate() called" )
305+
306+ # Start output handler on the first call to generate()
307+ self ._run_output_handler ()
308+
298309 if sampling_params_list is None :
299310 sampling_params_list = self .default_sampling_params_list
300311 if len (sampling_params_list ) != len (self .stage_list ):
@@ -339,6 +350,9 @@ async def generate(
339350 )
340351 # Seed stage-0 queue with all requests
341352 logger .debug ("[Orchestrator] Seeding request into stage-0" )
353+ req_state = ClientRequestState (request_id )
354+ self .request_states [request_id ] = req_state
355+
342356 # Mark first input time for stage-0
343357 metrics .stage_first_ts [0 ] = metrics .stage_first_ts [0 ] or time .time ()
344358
@@ -353,135 +367,120 @@ async def generate(
353367 logger .debug ("[Orchestrator] Enqueued request %s to stage-0" , request_id )
354368
355369 logger .debug ("[Orchestrator] Entering scheduling loop: stages=%d" , num_stages )
356- finished = False
357- while not finished :
358- made_progress = False
359- for stage_id , stage in enumerate (self .stage_list ):
360- result = stage .try_collect ()
361- if result is None :
362- continue
370+ for stage_id , stage in enumerate (self .stage_list ):
371+ result = await req_state .queue .get ()
372+ assert stage_id == req_state .stage_id
363373
364- made_progress = True
365- req_id = result .get ("request_id" )
366- if "error" in result :
367- logger .error (
368- "Stage %s error on request %s: %s" ,
369- stage_id ,
370- req_id ,
371- result ["error" ],
372- )
373- continue
374+ req_id = result .get ("request_id" )
375+ if "error" in result :
376+ logger .error (
377+ "Stage %s error on request %s: %s" ,
378+ stage_id ,
379+ req_id ,
380+ result ["error" ],
381+ )
382+ raise RuntimeError (result ) # Request Finished due to error
374383
375- if result .get ("type" ) == "stage_ready" :
376- # Only happens when stage is initialized slower than expected,
377- # so we wait for a short time and try again
378- time .sleep (0.05 )
379- continue
384+ engine_outputs = _load (result , obj_key = "engine_outputs" , shm_key = "engine_outputs_shm" )
385+ # Mark last output time for this stage whenever we receive outputs
386+ metrics .stage_last_ts [stage_id ] = max (metrics .stage_last_ts [stage_id ] or 0.0 , time .time ())
387+ try :
388+ _m = result .get ("metrics" )
389+ if _m is not None :
390+ metrics .on_stage_metrics (stage_id , req_id , _m )
391+ except Exception as e :
392+ logger .exception (
393+ "[Orchestrator] Failed to process metrics for stage %s, \
394+ req %s: %s" ,
395+ stage_id ,
396+ req_id ,
397+ e ,
398+ )
399+ logger .debug (
400+ "[Orchestrator] Stage-%s completed request %s; \
401+ forwarding or finalizing" ,
402+ stage_id ,
403+ req_id ,
404+ )
405+ stage .set_engine_outputs (engine_outputs )
380406
381- engine_outputs = _load (result , obj_key = "engine_outputs" , shm_key = "engine_outputs_shm" )
382- # Mark last output time for this stage whenever we receive outputs
383- metrics .stage_last_ts [stage_id ] = max (metrics .stage_last_ts [stage_id ] or 0.0 , time .time ())
407+ if getattr (stage , "final_output" , False ):
408+ logger .debug (
409+ "[Orchestrator] Request %s finalized at stage-%s" ,
410+ req_id ,
411+ stage_id ,
412+ )
413+
414+ # End-to-end timing and time-per-token for final output
415+ # (only once per request at the designated final stage)
384416 try :
385- _m = result .get ("metrics" )
386- if _m is not None :
387- metrics .on_stage_metrics (stage_id , req_id , _m )
417+ rid_key = str (req_id )
418+ if stage_id == final_stage_id_for_e2e and rid_key not in metrics .e2e_done :
419+ metrics .on_finalize_request (
420+ stage_id ,
421+ req_id ,
422+ engine_outputs ,
423+ _req_start_ts .get (req_id , _wall_start_ts ),
424+ )
388425 except Exception as e :
389426 logger .exception (
390- "[Orchestrator] Failed to process metrics for stage %s, \
391- req %s: %s" ,
392- stage_id ,
427+ "[Orchestrator] Finalize request handling error for \
428+ req %s at stage %s: %s" ,
393429 req_id ,
430+ stage_id ,
394431 e ,
395432 )
396- logger .debug (
397- "[Orchestrator] Stage-%s completed request %s; \
398- forwarding or finalizing" ,
399- stage_id ,
400- req_id ,
433+
434+ if isinstance (engine_outputs , list ):
435+ engine_outputs = engine_outputs [0 ]
436+ yield OmniRequestOutput (
437+ stage_id = stage_id ,
438+ final_output_type = stage .final_output_type ,
439+ request_output = engine_outputs ,
401440 )
402- stage .set_engine_outputs (engine_outputs )
403441
404- if getattr (stage , "final_output" , False ):
405- logger .debug (
406- "[Orchestrator] Request %s finalized at stage-%s" ,
407- req_id ,
408- stage_id ,
409- )
442+ next_stage_id = stage_id + 1
443+ if next_stage_id < num_stages :
444+ next_stage : OmniStage = self .stage_list [next_stage_id ]
445+ next_inputs = next_stage .process_engine_inputs (self .stage_list , prompt )
446+ sp_next : SamplingParams = sampling_params_list [next_stage_id ]
410447
411- # End-to-end timing and time-per-token for final output
412- # (only once per request at the designated final stage)
413- try :
414- rid_key = str (req_id )
415- if stage_id == final_stage_id_for_e2e and rid_key not in metrics .e2e_done :
416- metrics .on_finalize_request (
417- stage_id ,
418- req_id ,
419- engine_outputs ,
420- _req_start_ts .get (req_id , _wall_start_ts ),
421- )
422- except Exception as e :
423- logger .exception (
424- "[Orchestrator] Finalize request handling error for \
425- req %s at stage %s: %s" ,
426- req_id ,
427- stage_id ,
428- e ,
429- )
448+ # Check if we have a connector for this edge
449+ connector_key = (str (stage_id ), str (next_stage_id ))
450+ connector = self .connectors .get (connector_key )
430451
431- if isinstance (engine_outputs , list ):
432- engine_outputs = engine_outputs [0 ]
433- yield OmniRequestOutput (
452+ sent_via_connector = False
453+ if connector :
454+ sent_via_connector = try_send_via_connector (
455+ connector = connector ,
434456 stage_id = stage_id ,
435- final_output_type = stage .final_output_type ,
436- request_output = engine_outputs ,
457+ next_stage_id = next_stage_id ,
458+ req_id = req_id ,
459+ next_inputs = next_inputs ,
460+ sampling_params = sp_next ,
461+ original_prompt = prompt ,
462+ next_stage_queue_submit_fn = self .stage_list [next_stage_id ].submit ,
463+ metrics = metrics ,
437464 )
438465
439- next_stage_id = stage_id + 1
440- if next_stage_id < num_stages :
441- next_stage : OmniStage = self .stage_list [next_stage_id ]
442- next_inputs = next_stage .process_engine_inputs (self .stage_list , prompt )
443- sp_next : SamplingParams = sampling_params_list [next_stage_id ]
444-
445- # Check if we have a connector for this edge
446- connector_key = (str (stage_id ), str (next_stage_id ))
447- connector = self .connectors .get (connector_key )
448-
449- sent_via_connector = False
450- if connector :
451- sent_via_connector = try_send_via_connector (
452- connector = connector ,
453- stage_id = stage_id ,
454- next_stage_id = next_stage_id ,
455- req_id = req_id ,
456- next_inputs = next_inputs ,
457- sampling_params = sp_next ,
458- original_prompt = prompt ,
459- next_stage_queue_submit_fn = self .stage_list [next_stage_id ].submit ,
460- metrics = metrics ,
461- )
462-
463- if not sent_via_connector :
464- # Fallback logic removed as we now enforce connector usage.
465- # If no connector is found or send fails, we log an error and raise,
466- # because continuing would cause the request to be silently dropped
467- # and the orchestrator to hang waiting for completion.
468- error_msg = (
469- f"[Orchestrator] Failed to send request { req_id } to stage-{ next_stage_id } via connector. "
470- "Configure a connector for this edge or inspect connector logs for details."
471- )
472- logger .error (error_msg )
473- raise RuntimeError (error_msg )
474- logger .debug (
475- "[Orchestrator] Forwarded request %s to stage-%s" ,
476- req_id ,
477- next_stage_id ,
466+ if not sent_via_connector :
467+ # Fallback logic removed as we now enforce connector usage.
468+ # If no connector is found or send fails, we log an error and raise,
469+ # because continuing would cause the request to be silently dropped
470+ # and the orchestrator to hang waiting for completion.
471+ error_msg = (
472+ f"[Orchestrator] Failed to send request { req_id } to stage-{ next_stage_id } via connector. "
473+ "Configure a connector for this edge or inspect connector logs for details."
478474 )
479- else :
480- finished = True
481- logger .debug ("[Orchestrator] Request %s fully completed" , req_id )
482-
483- if not made_progress :
484- time .sleep (0.005 )
475+ logger .error (error_msg )
476+ raise RuntimeError (error_msg )
477+ logger .debug (
478+ "[Orchestrator] Forwarded request %s to stage-%s" ,
479+ req_id ,
480+ next_stage_id ,
481+ )
482+ else :
483+ logger .debug ("[Orchestrator] Request %s fully completed" , req_id )
485484
486485 logger .debug ("[Orchestrator] All requests completed" )
487486
@@ -491,6 +490,8 @@ async def generate(
491490 logger .info ("[Summary] %s" , summary )
492491 except Exception as e :
493492 logger .exception ("[Orchestrator] Failed to build/log summary: %s" , e )
493+ finally :
494+ self .request_states .pop (request_id , None )
494495
495496 def _wait_for_stages_ready (self , timeout : int = 120 ) -> None :
496497 num_stages = len (self .stage_list )
@@ -560,6 +561,51 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None:
560561 occurred while logging suggestions" ,
561562 )
562563
564+ def _run_output_handler (self ) -> None :
565+ if self .output_handler is not None :
566+ return
567+
568+ stage_list = self .stage_list
569+ request_states = self .request_states
570+
571+ async def output_handler ():
572+ try :
573+ while True :
574+ idle = True
575+ for stage_id , stage in enumerate (stage_list ):
576+ result = stage .try_collect ()
577+ if result is None :
578+ continue
579+ idle = False
580+ if result .get ("type" ) == "stage_ready" :
581+ # Only happens when stage is initialized slower than expected,
582+ # so we wait for a short time and try again
583+ await asyncio .sleep (0.05 )
584+ continue
585+ req_id = result .get ("request_id" )
586+ req_state = request_states .get (req_id )
587+ if req_state is None :
588+ logger .debug (
589+ "[Orchestrator] Request may have been aborted; \
590+ dropping output for req %s at stage-%s " ,
591+ req_id ,
592+ stage_id ,
593+ )
594+ continue
595+ await req_state .queue .put (result )
596+ req_state .stage_id = stage_id
597+ if idle :
598+ await asyncio .sleep (0.001 ) # Avoid CPU overload when idle
599+ else :
600+ await asyncio .sleep (0 )
601+ except Exception as e :
602+ logger .exception ("AsyncOmni output_handler failed." )
603+ for req_state in request_states .values ():
604+ await req_state .queue .put ({"request_id" : req_id , "error" : str (e )})
605+ self .output_handler = None # Make possible for restart
606+
607+ self .output_handler = asyncio .create_task (output_handler ())
608+
563609 @property
564610 def is_running (self ) -> bool :
565611 # Is None before the loop is started.
@@ -782,7 +828,7 @@ def __init__(
782828 )
783829 self .logger_manager .log_engine_initialized ()
784830
785- self .output_handler : Optional [ asyncio .Task ] = None
831+ self .output_handler : asyncio .Task | None = None
786832 try :
787833 # Start output handler eagerly if we are in the asyncio eventloop.
788834 asyncio .get_running_loop ()
0 commit comments