Skip to content

Commit 699eaa1

Browse files
author
zhongwei.ren
committed
[Benchmark] add auto-run param for hicache/bench_multiturn
1 parent e50109f commit 699eaa1

File tree

1 file changed

+60
-5
lines changed

1 file changed

+60
-5
lines changed

benchmark/hicache/bench_multiturn.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import aiohttp
1212
import requests
1313
from tqdm.asyncio import tqdm
14+
import numpy as np
1415

1516
from sglang.bench_serving import (
1617
RequestFuncOutput,
@@ -97,6 +98,30 @@ def parse_args():
9798
default="performance_metrics.jsonl",
9899
help="File to log performance metrics",
99100
)
101+
parser.add_argument(
102+
"--disable-auto-run",
103+
action="store_true",
104+
help="If set, disable automatically testing with a range of request rates.",
105+
)
106+
107+
parser.add_argument(
108+
"--disable-random-sample",
109+
action="store_true",
110+
help="If set, disable random sampling of requests from the ShareGPT dataset.",
111+
)
112+
parser.add_argument(
113+
"--sub-question-input-length",
114+
type=int,
115+
default=0,
116+
help="Length of the sub question input for each request, if set 0 use request_length",
117+
)
118+
parser.add_argument(
119+
"--ready-queue-policy",
120+
type=str,
121+
default="random",
122+
help="Policy for popping requests from the ready queue (random or fifo)",
123+
)
124+
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
100125
return parser.parse_args()
101126

102127

@@ -234,13 +259,29 @@ def __init__(self, args):
234259
self.candidate_inputs = sample_random_requests(
235260
input_len=args.request_length,
236261
output_len=args.output_length,
237-
num_prompts=args.num_clients * args.num_rounds,
262+
num_prompts=args.num_clients,
238263
range_ratio=1.0,
239264
tokenizer=self.tokenizer,
240265
dataset_path=args.dataset_path,
266+
random_sample=not args.disable_random_sample,
241267
)
242268
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
243269

270+
if args.sub_question_input_length != 0:
271+
sub_question_input_length = args.sub_question_input_length
272+
else:
273+
sub_question_input_length = args.request_length
274+
275+
self.sub_question_inputs = sample_random_requests(
276+
input_len=sub_question_input_length,
277+
output_len=args.output_length,
278+
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
279+
range_ratio=1.0,
280+
tokenizer=self.tokenizer,
281+
dataset_path=args.dataset_path,
282+
random_sample=not args.disable_random_sample,
283+
)
284+
244285
init_requests = [
245286
(i, gen_payload(self.candidate_inputs[i], args.output_length))
246287
for i in range(args.num_clients)
@@ -249,7 +290,7 @@ def __init__(self, args):
249290
i: {"round": 0, "history": init_requests[i][1]["text"]}
250291
for i in range(args.num_clients)
251292
}
252-
self.ready_queue = ReadyQueue(init_requests=init_requests)
293+
self.ready_queue = ReadyQueue(init_requests=init_requests, policy=args.ready_queue_policy)
253294
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
254295

255296
self.response_queue = queue.Queue()
@@ -314,9 +355,10 @@ def response_handler(self):
314355
self.completed_requests += 1
315356

316357
if self.client_records[client_id]["round"] < args.num_rounds:
358+
# append new request to client's history
317359
self.client_records[client_id][
318360
"history"
319-
] += self.candidate_inputs.pop()
361+
] += self.sub_question_inputs.pop()
320362
self.ready_queue.append(
321363
(
322364
client_id,
@@ -329,6 +371,9 @@ def response_handler(self):
329371
except queue.Empty:
330372
if self.pbar.n == self.pbar.total:
331373
break
374+
except ValueError as e:
375+
print(f"Error processing response for client {client_id}: {e}")
376+
continue
332377

333378
def run(self):
334379
request_thread = threading.Thread(target=self.request_sender, daemon=True)
@@ -388,8 +433,18 @@ def run(self):
388433
args = parse_args()
389434
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
390435

391-
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
392-
args.request_rate = request_rate
436+
random.seed(args.seed)
437+
np.random.seed(args.seed)
438+
439+
if args.disable_auto_run:
440+
print("Running with specified request rate...")
441+
request_rates = [args.request_rate]
442+
else:
443+
print("Auto-running with different request rates...")
444+
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
445+
446+
for rate in request_rates:
447+
args.request_rate = rate
393448
requests.post(flush_cache_url)
394449
time.sleep(1)
395450
WorkloadGenerator(args).run()

0 commit comments

Comments
 (0)