From 496e28dacd4e95b19885116783188fe104bfb087 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 22 May 2023 03:22:41 +0000
Subject: [PATCH 1/2] Add prompt_token_ids

---
 cacheflow/entrypoints/llm.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py
index 821b7db90308..acb9a7473ad9 100644
--- a/cacheflow/entrypoints/llm.py
+++ b/cacheflow/entrypoints/llm.py
@@ -35,18 +35,26 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         if sampling_params is None:
+            # Use default sampling params.
             sampling_params = SamplingParams()
         # Initialize tqdm.
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Processed prompts")
 
         # Add requests to the server.
-        for prompt in prompts:
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            if prompt_token_ids is None:
+                token_ids = None
+            else:
+                token_ids = prompt_token_ids[i]
             request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params)
+            self.llm_server.add_request(request_id, prompt, sampling_params,
+                                        token_ids)
 
         # Run the server.
         outputs: List[RequestOutput] = []

From 572275559754a19476db081e24fc4a7d150bf8e3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 22 May 2023 03:23:08 +0000
Subject: [PATCH 2/2] Fix benchmark_latency

---
 benchmark/benchmark_latency.py | 62 ++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 29 deletions(-)

diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index 410f7f0bd813..930b34a0074c 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -1,71 +1,75 @@
 import argparse
 import time
-from typing import List
 
-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm
 
-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams
 
 
 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )
 
     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
        max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency
 
-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)
 
     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
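
Usage note (not part of either patch above): a minimal sketch of how the new prompt_token_ids argument to LLM.generate() could be exercised outside the benchmark. It only uses names that appear in the patches (LLM, SamplingParams, generate, use_tqdm, RequestOutput); calling LLM() with just model and tensor_parallel_size assumes the remaining constructor arguments have usable defaults, and the token IDs are dummy placeholders rather than a real tokenization, as in the benchmark.

    # Hedged sketch, not the canonical API: assumes the cacheflow package from
    # this tree is importable and that LLM() accepts only `model` and
    # `tensor_parallel_size`, leaving the other arguments at their defaults.
    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
    sampling_params = SamplingParams(temperature=1.0, max_tokens=16)

    # One list of token IDs per prompt string; placeholder IDs, mirroring the
    # benchmark's dummy_prompt_token_ids. Per patch 1/2, each token-ID list is
    # forwarded to llm_server.add_request() together with its prompt string.
    prompts = ["", ""]
    prompt_token_ids = [[0, 0, 0], [0, 0, 0]]

    outputs = llm.generate(prompts, sampling_params, prompt_token_ids,
                           use_tqdm=False)
    print(outputs)  # generate() is annotated to return List[RequestOutput]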