1616import logging
1717import multiprocessing as mp
1818import os
19+ import queue
1920import sys
2021import traceback
2122from typing import Any
@@ -1135,19 +1136,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
11351136 )
11361137 except Exception as e :
11371138 _logging .getLogger (__name__ ).warning ("[Stage-%s] Failed to send stage ready signal: %s" , stage_id , e )
1138-
1139+ generation_out_q = asyncio . Queue ()
11391140 # Batch processing loop
1140- while True :
1141- task = in_q .get ()
1142- _recv_dequeue_ts = _time .time ()
1143- if task is None :
1144- _logging .getLogger (__name__ ).debug ("[Stage-%s] Received shutdown signal" , stage_id )
1145- break
1146-
1147- _rx_bytes_by_rid : dict [Any , int ] = {}
1148- _rx_decode_ms_by_rid : dict [Any , float ] = {}
1149- _in_flight_ms_by_rid : dict [Any , float ] = {}
1141+ _rx_bytes_by_rid : dict [Any , int ] = {}
1142+ _rx_decode_ms_by_rid : dict [Any , float ] = {}
1143+ _in_flight_ms_by_rid : dict [Any , float ] = {}
11501144
1145+ async def generation_single_request (task : dict [str , Any ]):
1146+ _recv_dequeue_ts = _time .time ()
11511147 rid = task ["request_id" ]
11521148 try :
11531149 sent_ts = float (task .get ("sent_ts" , None )) if isinstance (task , dict ) else None
@@ -1157,62 +1153,101 @@ def filter(self, record: _logging.LogRecord) -> bool:
11571153 _in_flight_ms_by_rid [rid ] = 0.0
11581154 except Exception :
11591155 _in_flight_ms_by_rid [rid ] = 0.0
1160- ein , _rx_metrics = try_recv_via_connector (
1161- task = task ,
1162- connectors = connectors ,
1163- stage_id = stage_id ,
1164- )
1165- if ein is None or _rx_metrics is None :
1166- raise RuntimeError (
1167- f"[Stage-{ stage_id } ] Missing connector payload for request { rid } . "
1168- "Ensure connectors are configured for all incoming edges."
1169- )
1170- _rx_decode_ms_by_rid [rid ] = float (_rx_metrics .get ("rx_decode_time_ms" , 0.0 ))
1171- _rx_bytes_by_rid [rid ] = int (_rx_metrics .get ("rx_transfer_bytes" , 0 ))
1172-
1173- sampling_params = task ["sampling_params" ]
1174- _logging .getLogger (__name__ ).debug ("[Stage-%s] Received batch size=1, request_ids=%s" , stage_id , rid )
1175- print ("--------------------------------" , flush = True )
1176- print (f"[Stage-{ stage_id } ] Received batch size=1, request_ids={ rid } " , flush = True )
1177- print ("--------------------------------" , flush = True )
11781156 try :
1179- _batch_seq += 1
1157+ ein , _rx_metrics = try_recv_via_connector (
1158+ task = task ,
1159+ connectors = connectors ,
1160+ stage_id = stage_id ,
1161+ )
1162+ if ein is None or _rx_metrics is None :
1163+ raise RuntimeError (
1164+ f"[Stage-{ stage_id } ] Missing connector payload for request { rid } . "
1165+ "Ensure connectors are configured for all incoming edges."
1166+ )
1167+ _rx_decode_ms_by_rid [rid ] = float (_rx_metrics .get ("rx_decode_time_ms" , 0.0 ))
1168+ _rx_bytes_by_rid [rid ] = int (_rx_metrics .get ("rx_transfer_bytes" , 0 ))
1169+
1170+ sampling_params = task ["sampling_params" ]
1171+ _logging .getLogger (__name__ ).debug ("[Stage-%s] Received batch size=1, request_ids=%s" , stage_id , rid )
1172+ print ("--------------------------------" , flush = True )
1173+ print (f"[Stage-{ stage_id } ] Received batch size=1, request_ids={ rid } " , flush = True )
1174+ print ("--------------------------------" , flush = True )
11801175 _gen_t0 = _time .time ()
11811176 if isinstance (ein , list ):
11821177 ein = ein [0 ]
1183-
11841178 async for res in stage_engine .generate (ein , sampling_params , rid ):
11851179 gen_output = res
11861180 _gen_t1 = _time .time ()
11871181 _gen_ms = (_gen_t1 - _gen_t0 ) * 1000.0
1182+ await generation_out_q .put ((rid , gen_output , _gen_ms ))
1183+ except Exception as e :
1184+ _logging .getLogger (__name__ ).exception ("[Stage-%s] Failed on request %s: %s" , stage_id , rid , e )
1185+ out_q .put (
1186+ {
1187+ "request_id" : rid ,
1188+ "stage_id" : stage_id ,
1189+ "error" : str (e ),
1190+ }
1191+ )
11881192
1189- r_outputs = [gen_output ]
1190- _num_tokens = count_tokens_from_outputs (r_outputs )
1191- _agg_total_tokens += _num_tokens
1192- _agg_total_gen_time_ms += _gen_ms
1193-
1194- if _stats_file :
1195- _avg_tokens_per_s = (
1196- (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms ) if _agg_total_gen_time_ms > 0 else 0.0
1197- )
1198- log_stage_running_avg (
1199- _stats_file ,
1200- stage_id ,
1201- int (_agg_total_tokens ),
1202- float (_agg_total_gen_time_ms ),
1203- float (_avg_tokens_per_s ),
1204- )
1193+ _batch_gen_t0 = _time .time ()
1194+ while True :
1195+ try :
1196+ task = in_q .get_nowait ()
1197+ if task is None :
1198+ _logging .getLogger (__name__ ).debug ("[Stage-%s] Received shutdown signal" , stage_id )
1199+ break
1200+ asyncio .create_task (generation_single_request (task ))
1201+ except queue .Empty :
1202+ await asyncio .sleep (0.001 )
1203+ batch_request_outputs : list [Any ] = []
1204+ batch_request_ids : list [Any ] = []
1205+ _gen_ms_list = []
1206+ while True :
1207+ try :
1208+ rids , gen_output , _gen_ms = generation_out_q .get_nowait ()
1209+ _num_tokens = count_tokens_from_outputs ([gen_output ])
1210+ batch_request_outputs .append (gen_output )
1211+ _gen_ms_list .append (_gen_ms )
1212+ batch_request_ids .append (rids )
1213+ _agg_total_tokens += _num_tokens
1214+ except asyncio .QueueEmpty :
1215+ await asyncio .sleep (0.001 )
1216+ break
1217+
1218+ if not batch_request_outputs :
1219+ continue
1220+ _batch_seq += 1
1221+ if _stats_file :
1222+ _batch_gen_t1 = _time .time ()
1223+ _agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0 ) * 1000
1224+ _batch_gen_t0 = _batch_gen_t1
1225+ _avg_tokens_per_s = (
1226+ (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms ) if _agg_total_gen_time_ms > 0 else 0.0
1227+ )
1228+ log_stage_running_avg (
1229+ _stats_file ,
1230+ stage_id ,
1231+ int (_agg_total_tokens ),
1232+ float (_agg_total_gen_time_ms ),
1233+ float (_avg_tokens_per_s ),
1234+ )
1235+ logger .info ("[Stage-%s] Running avg: %s tokens/s" , stage_id , _avg_tokens_per_s )
1236+ for rid , _gen_ms in zip (batch_request_ids , _gen_ms_list ):
12051237 log_stage_batch_stats (_stats_file , stage_id , 1 , float (_gen_ms ), [rid ])
12061238
1239+ logger .info ("[Stage-%s] Sending outputs to main process" , stage_id )
1240+ for rid , output , _gen_ms in zip (batch_request_ids , batch_request_outputs , _gen_ms_list ):
12071241 try :
1242+ r_outputs = [output ]
12081243 use_shm , payload = maybe_dump_to_shm (r_outputs , shm_threshold_bytes )
12091244 _metrics = {
12101245 "num_tokens_out" : int (count_tokens_from_outputs (r_outputs )),
12111246 "stage_gen_time_ms" : _gen_ms ,
12121247 "batch_id" : int (_batch_seq ),
1213- "rx_decode_time_ms" : float (_rx_decode_ms_by_rid .get (rid , 0.0 )),
1214- "rx_transfer_bytes" : int (_rx_bytes_by_rid .get (rid , 0 )),
1215- "rx_in_flight_time_ms" : float (_in_flight_ms_by_rid .get (rid , 0.0 )),
1248+ "rx_decode_time_ms" : float (_rx_decode_ms_by_rid .pop (rid , 0.0 )),
1249+ "rx_transfer_bytes" : int (_rx_bytes_by_rid .pop (rid , 0 )),
1250+ "rx_in_flight_time_ms" : float (_in_flight_ms_by_rid .pop (rid , 0.0 )),
12161251 }
12171252 if _stats_file :
12181253 compute_and_log_stage_request_stats (
@@ -1266,23 +1301,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
12661301 "metrics" : {
12671302 "num_tokens_out" : int (count_tokens_from_outputs (r_outputs )),
12681303 "stage_gen_time_ms" : _gen_ms ,
1269- "rx_decode_time_ms" : float (_rx_decode_ms_by_rid .get (rid , 0.0 )),
1270- "rx_transfer_bytes" : int (_rx_bytes_by_rid .get (rid , 0 )),
1271- "rx_in_flight_time_ms" : float (_in_flight_ms_by_rid .get (rid , 0.0 )),
1304+ "rx_decode_time_ms" : float (_rx_decode_ms_by_rid .pop (rid , 0.0 )),
1305+ "rx_transfer_bytes" : int (_rx_bytes_by_rid .pop (rid , 0 )),
1306+ "rx_in_flight_time_ms" : float (_in_flight_ms_by_rid .pop (rid , 0.0 )),
12721307 },
12731308 }
12741309 )
12751310 _logging .getLogger (__name__ ).debug ("[Stage-%s] Enqueued result for request %s to downstream" , stage_id , rid )
12761311
1277- except Exception as e :
1278- _logging .getLogger (__name__ ).exception ("[Stage-%s] Failed on request %s: %s" , stage_id , rid , e )
1279- out_q .put (
1280- {
1281- "request_id" : rid ,
1282- "stage_id" : stage_id ,
1283- "error" : str (e ),
1284- }
1285- )
12861312 print ("--------------------------------" , flush = True )
12871313 print (f"[Stage-{ stage_id } ] Stage worker exiting" , flush = True )
12881314 print ("--------------------------------" , flush = True )
0 commit comments