Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions benchmark/hicache/bench_multiturn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Optional

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm

Expand Down Expand Up @@ -97,6 +98,30 @@ def parse_args():
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
parser.add_argument(
"--disable-auto-run",
action="store_true",
help="If set, disable automatically testing with a range of request rates.",
)

parser.add_argument(
"--disable-random-sample",
action="store_true",
help="If set, disable random sampling of requests from the ShareGPT dataset.",
)
parser.add_argument(
"--sub-question-input-length",
type=int,
default=0,
help="Length of the sub question input for each request, if set 0 use request_length",
)
parser.add_argument(
"--ready-queue-policy",
type=str,
default="random",
help="Policy for popping requests from the ready queue (random or fifo)",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
return parser.parse_args()


Expand Down Expand Up @@ -234,13 +259,29 @@ def __init__(self, args):
self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
num_prompts=args.num_clients * args.num_rounds,
num_prompts=args.num_clients,
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]

if args.sub_question_input_length != 0:
sub_question_input_length = args.sub_question_input_length
else:
sub_question_input_length = args.request_length

self.sub_question_inputs = sample_random_requests(
input_len=sub_question_input_length,
output_len=args.output_length,
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)

init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length))
for i in range(args.num_clients)
Expand All @@ -249,7 +290,9 @@ def __init__(self, args):
i: {"round": 0, "history": init_requests[i][1]["text"]}
for i in range(args.num_clients)
}
self.ready_queue = ReadyQueue(init_requests=init_requests)
self.ready_queue = ReadyQueue(
init_requests=init_requests, policy=args.ready_queue_policy
)
self.candidate_inputs = self.candidate_inputs[args.num_clients :]

self.response_queue = queue.Queue()
Expand Down Expand Up @@ -314,9 +357,10 @@ def response_handler(self):
self.completed_requests += 1

if self.client_records[client_id]["round"] < args.num_rounds:
# append new request to client's history
self.client_records[client_id][
"history"
] += self.candidate_inputs.pop()
] += self.sub_question_inputs.pop()
self.ready_queue.append(
(
client_id,
Expand All @@ -329,6 +373,9 @@ def response_handler(self):
except queue.Empty:
if self.pbar.n == self.pbar.total:
break
except ValueError as e:
print(f"Error processing response for client {client_id}: {e}")
continue

def run(self):
request_thread = threading.Thread(target=self.request_sender, daemon=True)
Expand Down Expand Up @@ -388,8 +435,18 @@ def run(self):
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
args.request_rate = request_rate
random.seed(args.seed)
np.random.seed(args.seed)

if args.disable_auto_run:
print("Running with specified request rate...")
request_rates = [args.request_rate]
else:
print("Auto-running with different request rates...")
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]

for rate in request_rates:
args.request_rate = rate
requests.post(flush_cache_url)
time.sleep(1)
WorkloadGenerator(args).run()
Loading