1111import aiohttp
1212import requests
1313from tqdm .asyncio import tqdm
14+ import numpy as np
1415
1516from sglang .bench_serving import (
1617 RequestFuncOutput ,
@@ -97,6 +98,30 @@ def parse_args():
9798 default = "performance_metrics.jsonl" ,
9899 help = "File to log performance metrics" ,
99100 )
101+ parser .add_argument (
102+ "--disable-auto-run" ,
103+ action = "store_true" ,
104+ help = "If set, disable automatically testing with a range of request rates." ,
105+ )
106+
107+ parser .add_argument (
108+ "--disable-random-sample" ,
109+ action = "store_true" ,
110+ help = "If set, disable random sampling of requests from the ShareGPT dataset." ,
111+ )
112+ parser .add_argument (
113+ "--sub-question-input-length" ,
114+ type = int ,
115+ default = 0 ,
116+ help = "Length of the sub question input for each request, if set 0 use request_length" ,
117+ )
118+ parser .add_argument (
119+ "--ready-queue-policy" ,
120+ type = str ,
121+ default = "random" ,
122+ help = "Policy for popping requests from the ready queue (random or fifo)" ,
123+ )
124+ parser .add_argument ("--seed" , type = int , default = 1 , help = "The random seed." )
100125 return parser .parse_args ()
101126
102127
@@ -234,13 +259,29 @@ def __init__(self, args):
234259 self .candidate_inputs = sample_random_requests (
235260 input_len = args .request_length ,
236261 output_len = args .output_length ,
237- num_prompts = args .num_clients * args . num_rounds ,
262+ num_prompts = args .num_clients ,
238263 range_ratio = 1.0 ,
239264 tokenizer = self .tokenizer ,
240265 dataset_path = args .dataset_path ,
266+ random_sample = not args .disable_random_sample ,
241267 )
242268 self .candidate_inputs = [i .prompt for i in self .candidate_inputs ]
243269
270+ if args .sub_question_input_length != 0 :
271+ sub_question_input_length = args .sub_question_input_length
272+ else :
273+ sub_question_input_length = args .request_length
274+
275+ self .sub_question_inputs = sample_random_requests (
276+ input_len = sub_question_input_length ,
277+ output_len = args .output_length ,
278+ num_prompts = args .num_clients * max (args .num_rounds - 1 , 1 ),
279+ range_ratio = 1.0 ,
280+ tokenizer = self .tokenizer ,
281+ dataset_path = args .dataset_path ,
282+ random_sample = not args .disable_random_sample ,
283+ )
284+
244285 init_requests = [
245286 (i , gen_payload (self .candidate_inputs [i ], args .output_length ))
246287 for i in range (args .num_clients )
@@ -249,7 +290,7 @@ def __init__(self, args):
249290 i : {"round" : 0 , "history" : init_requests [i ][1 ]["text" ]}
250291 for i in range (args .num_clients )
251292 }
252- self .ready_queue = ReadyQueue (init_requests = init_requests )
293+ self .ready_queue = ReadyQueue (init_requests = init_requests , policy = args . ready_queue_policy )
253294 self .candidate_inputs = self .candidate_inputs [args .num_clients :]
254295
255296 self .response_queue = queue .Queue ()
@@ -314,9 +355,10 @@ def response_handler(self):
314355 self .completed_requests += 1
315356
316357 if self .client_records [client_id ]["round" ] < args .num_rounds :
358+ # append new request to client's history
317359 self .client_records [client_id ][
318360 "history"
319- ] += self .candidate_inputs .pop ()
361+ ] += self .sub_question_inputs .pop ()
320362 self .ready_queue .append (
321363 (
322364 client_id ,
@@ -329,6 +371,9 @@ def response_handler(self):
329371 except queue .Empty :
330372 if self .pbar .n == self .pbar .total :
331373 break
374+ except ValueError as e :
375+ print (f"Error processing response for client { client_id } : { e } " )
376+ continue
332377
333378 def run (self ):
334379 request_thread = threading .Thread (target = self .request_sender , daemon = True )
@@ -388,8 +433,18 @@ def run(self):
388433 args = parse_args ()
389434 flush_cache_url = f"http://{ args .host } :{ args .port } /flush_cache"
390435
391- for request_rate in [16 , 14 , 12 , 10 , 9 , 8 , 7 , 6 , 5 , 4 , 3 , 2 , 1 ]:
392- args .request_rate = request_rate
436+ random .seed (args .seed )
437+ np .random .seed (args .seed )
438+
439+ if args .disable_auto_run :
440+ print ("Running with specified request rate..." )
441+ request_rates = [args .request_rate ]
442+ else :
443+ print ("Auto-running with different request rates..." )
444+ request_rates = [16 , 14 , 12 , 10 , 9 , 8 , 7 , 6 , 5 , 4 , 3 , 2 , 1 ]
445+
446+ for rate in request_rates :
447+ args .request_rate = rate
393448 requests .post (flush_cache_url )
394449 time .sleep (1 )
395450 WorkloadGenerator (args ).run ()
0 commit comments