1+ """
2+ Benchmark the efficiency of prefix caching.
3+
4+ This script allows you to benchmark the performance of
5+ a model with and without prefix caching using either fixed prompts
6+ or prompts sampled from the ShareGPT dataset.
7+
8+ Fixed example usage:
9+ python benchmark_prefix_caching.py \
10+ --model meta-llama/Llama-2-7b-chat-hf \
11+ --enable-prefix-caching \
12+ --num-prompts 1 \
13+ --repeat-count 100
14+
15+ ShareGPT example usage:
16+ # This command samples 20 prompts with input lengths
17+ # between 128 and 256 tokens from the ShareGPT dataset,
18+ # then replicates each prompt 5 times.
19+ python benchmark_prefix_caching.py \
20+ --model meta-llama/Llama-2-7b-chat-hf \
21+ --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
22+ --enable-prefix-caching \
23+ --num-prompts 20 \
24+ --repeat-count 5 \
25+ --input-length-range 128:256
26+ """
27+
28+ import json
29+ import random
130import time
31+ from typing import List , Optional , Tuple
32+
33+ from transformers import PreTrainedTokenizerBase
234
335from vllm import LLM , SamplingParams
436from vllm .utils import FlexibleArgumentParser
537
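+# Prefer vLLM's bundled get_tokenizer; fall back to the local benchmark
+# helper module when this vLLM import path is unavailable.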
+try:
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+except ImportError:
+    from backend_request_func import get_tokenizer
+
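+# A fixed long prompt (markdown table QA). Repeating it gives every request
+# an identical prefix, the best case for automatic prefix caching.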
 PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.\n # Table\n |Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n |----|----|----|----|----|----|----|----|\n |J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n |J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n |J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n |J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n |F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n |F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n |F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n |F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n |F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n |F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n |M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n |M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n |M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n |M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n |M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n \n # Question\n What's the content in the (1,1) cell?\n "  # noqa: E501


@@ -15,7 +52,83 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
     print(f"cost time {end_time - start_time}")


+def sample_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    input_length_range: Tuple[int, int],
+    fixed_output_len: Optional[int],
+) -> List[Tuple[str, int, int]]:
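+    """Sample prompts from a ShareGPT-style JSON dump.
+
+    Returns a list of (prompt, prompt_len, output_len) tuples whose tokenized
+    prompt length falls within input_length_range.
+    """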
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+    # Filter out conversations with fewer than 2 turns.
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+    # Only keep the first two turns of each conversation.
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
+
+    # Shuffle the dataset.
+    random.shuffle(dataset)
+
+    min_len, max_len = input_length_range
+
+    # Filter out sequences that are too long or too short.
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
+        prompt_len = len(prompt_token_ids)
+        output_len = (len(completion_token_ids)
+                      if fixed_output_len is None else fixed_output_len)
+        if prompt_len < 4 or output_len < 4:
+            # Prune too-short sequences.
+            continue
+        if min_len <= prompt_len <= max_len:
+            filtered_dataset.append((prompt, prompt_len, output_len))
+
+    return filtered_dataset
+
+
+def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
+                             repeat_count: int,
+                             sort: bool = False) -> List[str]:
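+    """Repeat each request repeat_count times, then either sort the result
+    by prompt length (sort=True) or shuffle it. Returns only the prompts."""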
+    repeated_requests = requests * repeat_count
+    if sort:
+        repeated_requests.sort(key=lambda x: x[1])
+    else:
+        random.shuffle(repeated_requests)
+    return [req[0] for req in repeated_requests]
+
+
 def main(args):
+    tokenizer = get_tokenizer(args.model, trust_remote_code=True)
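+    # Parse "min:max" (e.g. "128:256") into an inclusive token-length range.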
+    input_length_range = tuple(map(int, args.input_length_range.split(':')))
+
+    if args.dataset_path is not None:
+        print(f"Sampling {args.num_prompts} prompts "
+              f"from {args.dataset_path}")
+        filtered_datasets = sample_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            input_length_range=input_length_range,
+            fixed_output_len=args.output_len,
+        )
+    else:
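+        # No dataset given: fall back to repeating the fixed table-QA PROMPT.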
+        prompt_len = len(tokenizer(PROMPT).input_ids)
+        filtered_datasets = [(PROMPT, prompt_len, args.output_len)
+                             ] * args.num_prompts
+
     llm = LLM(model=args.model,
               tokenizer_mode='auto',
               trust_remote_code=True,
@@ -24,10 +137,13 @@ def main(args):
               tensor_parallel_size=args.tensor_parallel_size,
               enable_prefix_caching=args.enable_prefix_caching)

-    num_prompts = 100
-    prompts = [PROMPT] * num_prompts
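+    # temperature=0 means greedy decoding, so outputs are deterministic and
+    # timings are comparable between cached and uncached runs.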
     sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

+    print("Testing filtered datasets")
+    prompts = repeat_and_sort_requests(filtered_datasets,
+                                       repeat_count=args.repeat_count,
+                                       sort=args.sort)
+
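+    # Warm-up pass: triggers one-time initialization and, with prefix caching
+    # enabled, pre-populates the cache before the timed run.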
     print("------warm up------")
     test_prefix(
         llm=llm,
@@ -45,11 +161,15 @@ def main(args):

 if __name__ == "__main__":
     parser = FlexibleArgumentParser(
-        description='Benchmark the performance with or without automatic '
-        'prefix caching.')
+        description=
+        'Benchmark the performance with or without automatic prefix caching.')
     parser.add_argument('--model',
                         type=str,
                         default='baichuan-inc/Baichuan2-13B-Chat')
+    parser.add_argument("--dataset-path",
+                        type=str,
+                        default=None,
+                        help="Path to the ShareGPT dataset JSON file.")
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--output-len', type=int, default=10)
     parser.add_argument('--enable-prefix-caching',
@@ -58,5 +178,21 @@ def main(args):
     parser.add_argument('--use-v2-block-manager',
                         action='store_true',
                         help='Use BlockSpaceManagerV2')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to sample from the dataset')
+    parser.add_argument('--repeat-count',
+                        type=int,
+                        default=100,
+                        help='Number of times to repeat each prompt')
+    parser.add_argument('--sort',
+                        action='store_true',
+                        help='Sort prompts by input length')
+    parser.add_argument('--input-length-range',
+                        type=str,
+                        default='128:256',
+                        help='Range of input lengths for sampling prompts, '
+                        'specified as "min:max" (e.g., "128:256").')
     args = parser.parse_args()
     main(args)