@@ -529,14 +529,6 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
529529 # Start output handler on the first call to generate()
530530 self ._run_output_handler ()
531531
532- if sampling_params_list is None :
533- sampling_params_list = self .default_sampling_params_list
534- if len (sampling_params_list ) != len (self .stage_list ):
535- raise ValueError (
536- f"Expected { len (self .stage_list )} sampling params, \
537- got { len (sampling_params_list )} "
538- )
539-
540532 prompt = args [0 ] if args else kwargs .get ("prompt" )
541533 request_id = args [1 ] if len (args ) > 1 else kwargs .get ("request_id" )
542534 sampling_params_list = args [2 ] if len (args ) > 2 else kwargs .get ("sampling_params_list" )
@@ -597,6 +589,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
597589
598590 sp0 : SamplingParams = sampling_params_list [0 ] # type: ignore[index]
599591 task = {
592+ "type" : OmniStageTaskType .GENERATE ,
600593 "request_id" : request_id ,
601594 "engine_inputs" : prompt ,
602595 "sampling_params" : sp0 ,
@@ -639,10 +632,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
639632 stage_id ,
640633 req_id ,
641634 )
642- # Seed stage-0 queue with all requests
643- logger .debug ("[Orchestrator] Seeding request into stage-0" )
644- req_state = ClientRequestState (request_id )
645- self .request_states [request_id ] = req_state
635+ stage .set_engine_outputs (engine_outputs )
646636
647637 if getattr (stage , "final_output" , False ):
648638 logger .debug (
@@ -651,45 +641,22 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
651641 stage_id ,
652642 )
653643
654- sp0 : SamplingParams = sampling_params_list [0 ] # type: ignore[index]
655- task = {
656- "type" : OmniStageTaskType .GENERATE ,
657- "request_id" : request_id ,
658- "engine_inputs" : prompt ,
659- "sampling_params" : sp0 ,
660- }
661- self .stage_list [0 ].submit (task )
662- _req_start_ts [request_id ] = time .time ()
663- logger .debug ("[Orchestrator] Enqueued request %s to stage-0" , request_id )
664-
665- logger .debug ("[Orchestrator] Entering scheduling loop: stages=%d" , num_stages )
666- for stage_id , stage in enumerate (self .stage_list [: final_stage_id_for_e2e + 1 ]):
667- result = await req_state .queue .get ()
668- assert stage_id == req_state .stage_id
669-
670- req_id = result .get ("request_id" )
671- if "error" in result :
672- logger .error (
673- "Stage %s error on request %s: %s" ,
674- stage_id ,
675- req_id ,
676- result ["error" ],
677- )
678- raise RuntimeError (result ) # Request Finished due to error
679-
680- engine_outputs = _load (result , obj_key = "engine_outputs" , shm_key = "engine_outputs_shm" )
681- # Mark last output time for this stage whenever we receive outputs
682- metrics .stage_last_ts [stage_id ] = max (metrics .stage_last_ts [stage_id ] or 0.0 , time .time ())
644+ # End-to-end timing and time-per-token for final output
645+ # (only once per request at the designated final stage)
683646 try :
684- _m = asdict (result .get ("metrics" ))
685- if _m is not None :
686- metrics .on_stage_metrics (stage_id , req_id , _m )
647+ rid_key = str (req_id )
648+ if stage_id == final_stage_id_for_e2e and rid_key not in metrics .e2e_done :
649+ metrics .on_finalize_request (
650+ stage_id ,
651+ req_id ,
652+ engine_outputs ,
653+ _req_start_ts .get (req_id , _wall_start_ts ),
654+ )
687655 except Exception as e :
688656 logger .exception (
689657 "[AsyncOrchestrator] Finalize request handling error for req %s at stage %s: %s" ,
690658 req_id ,
691659 stage_id ,
692- req_id ,
693660 e ,
694661 )
695662
@@ -754,6 +721,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
754721 logger .debug (
755722 "[AsyncOrchestrator] Forwarded request %s to stage-%s" ,
756723 req_id ,
724+ next_stage_id ,
757725 )
758726 else :
759727 logger .debug ("[AsyncOrchestrator] Request %s fully completed" , req_id )
@@ -770,7 +738,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
770738 self .request_states .pop (request_id , None )
771739 except (asyncio .CancelledError , GeneratorExit ):
772740 await self .abort (request_id )
773- print ( " Request %s aborted." , request_id )
741+ logger . exception ( "[AsyncOrchestrator] Request %s aborted." , request_id )
774742 raise
775743
776744 def _wait_for_stages_ready (self , timeout : int = 120 ) -> None :
0 commit comments