diff --git a/.gitignore b/.gitignore
index 0cba7e920a5b..3b504fe57a9a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,7 @@
 *.eggs/
 *.so
 build/
+log/
+archive/
+*.csv
+.vscode/
diff --git a/benchmark/batch_benchmark.sh b/benchmark/batch_benchmark.sh
new file mode 100644
index 000000000000..9b17f47b5e1b
--- /dev/null
+++ b/benchmark/batch_benchmark.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+mkdir -p log
+
+MODEL_LOG_NAME="opt-13b"
+MODEL="facebook/opt-13b"
+
+for BATCH_SIZE in 8 32 128; do
+    for INPUT_LEN in 1 32 256 1024; do
+        for OUTPUT_LEN in 1 16 128; do
+            for TENSOR_PARALLEL_SIZE in 1 2 4; do
+                python benchmark_latency.py \
+                    --model $MODEL \
+                    --batch-size $BATCH_SIZE \
+                    --input-len $INPUT_LEN \
+                    --output-len $OUTPUT_LEN \
+                    --tensor-parallel-size $TENSOR_PARALLEL_SIZE \
+                    | tee -a log/model_${MODEL_LOG_NAME}_bs_${BATCH_SIZE}_in_${INPUT_LEN}_out_${OUTPUT_LEN}_tp_${TENSOR_PARALLEL_SIZE}.log
+                sleep 0.1
+            done
+        done
+    done
+done
diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index a18ef98f40a8..ff1d2d23c271 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -1,3 +1,4 @@
+import json
 import argparse
 import time
 from typing import List
@@ -11,17 +12,24 @@
                                      initialize_ray_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
+from cacheflow.profile import set_sync_for_profiling
 
 
 def main(args: argparse.Namespace):
+    print(json.dumps(args.__dict__))
+    set_sync_for_profiling()
+
     # TODO(zhuohan): Support pipeline parallelism.
     assert args.pipeline_parallel_size == 1, (
         'Pipeline parallelism is not supported yet.')
 
+    cuda_profiler = False
+    ray_cluster_address = "local" if cuda_profiler else "auto"
+
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
         initialize_ray_cluster(
-            address='local',
+            address=ray_cluster_address,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
@@ -60,32 +68,55 @@ def main(args: argparse.Namespace):
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
     input_token_ids = [0] * args.input_len
 
-    def profile_step(profile=False):
-        if profile:
+    def profile_step():
+        if cuda_profiler:
             torch.cuda.cudart().cudaProfilerStart()
+        server.reset_timer()
         for _ in range(args.batch_size):
             frontend._add_query(input_token_ids, sampling_params)
         server.add_sequence_groups(frontend.get_inputs())
+        # Prompt step
         start_time = time.time()
-        while True:
+        server.step()
+        end_time = time.time()
+        prompt_latency = end_time - start_time
+        # Decoding steps
+        num_decoding_steps = 0
+        start_time = time.time()
+        while server.has_unfinished_requests():
             server.step()
-            if not server.has_unfinished_requests():
-                break
+            num_decoding_steps += 1
         end_time = time.time()
-        latency = end_time - start_time
-        if profile:
+        decoding_latency = end_time - start_time
+        if cuda_profiler:
             torch.cuda.cudart().cudaProfilerStop()
-        return latency
+        server_profile_results = server.get_profile_results()
+        # First controller's first worker
+        worker_execution_latency = server_profile_results[0][0]["execution_latency"]
+        worker_communication_latency = server_profile_results[0][0]["communication_latency"]
+        return (prompt_latency, decoding_latency, num_decoding_steps,
+                worker_execution_latency, worker_communication_latency)
 
-    print("Warm up step")
+    print("== Warm up step ==")
     profile_step()
 
     # Benchmark.
-    latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
-    print(f'Avg latency: {np.mean(latencies)} seconds')
-
+    print("== Profile steps ==")
+    num_profile_steps = 5
+    for step in range(num_profile_steps):
+        (prompt_latency, decoding_latency, num_decoding_steps,
+         worker_execution_latency, worker_communication_latency) = profile_step()
+        decoding_latency_per_step = decoding_latency / num_decoding_steps if num_decoding_steps > 0 else 0.0
+        result = {
+            "step": step,
+            "prompt_latency_seconds": prompt_latency,
+            "decoding_latency_seconds": decoding_latency,
+            "decoding_latency_per_step_seconds": decoding_latency_per_step,
+            "num_decoding_steps": num_decoding_steps,
+            "worker_execution_latency_seconds": worker_execution_latency,
+            "worker_communication_latency_seconds": worker_communication_latency,
+        }
+        print(json.dumps(result))
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='CacheFlow simple server.')
@@ -95,5 +126,4 @@ def profile_step(profile=False):
     parser.add_argument('--batch-size', type=int, default=8)
     args = parser.parse_args()
     args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
-    print(args)
     main(args)
diff --git a/benchmark/parse_log.py b/benchmark/parse_log.py
new file mode 100644
index 000000000000..da25890048b8
--- /dev/null
+++ b/benchmark/parse_log.py
@@ -0,0 +1,45 @@
+import csv
+import json
+import os
+from argparse import Namespace
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+
+log_dir = 'log/'
+log_files = os.listdir(log_dir)
+all_results = []
+
+for log_file in log_files:
+    file_path = os.path.join(log_dir, log_file)
+    lines = list(open(file_path).readlines())
+    profile_arguments = json.loads(lines[0])
+    results = defaultdict(list)
+    for line in lines:
+        if "prompt_latency_seconds" not in line:
+            continue
+        result = json.loads(line)
+        for k, v in result.items():
+            if k == "step":
+                continue
+            results[k].append(v)
+    final_result = {
+        "model": profile_arguments["model"],
+        "batch_size": profile_arguments["batch_size"],
+        "input_len": profile_arguments["input_len"],
+        "output_len": profile_arguments["output_len"],
+        "tensor_parallel_size": profile_arguments["tensor_parallel_size"],
+    }
+
+    for k, v in results.items():
+        final_result[k + "_mean"] = np.mean(v)
+        final_result[k + "_std"] = np.std(v)
+
+    all_results.append(final_result)
+
+df = pd.DataFrame.from_records(all_results)
+
+print(df)
+
+df.to_csv('parse_result.csv', index=False)
\ No newline at end of file
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 1f224316c01b..f00ce612fad0 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -95,6 +95,14 @@ def has_unfinished_requests(self):
         return (self.scheduler.waiting or self.scheduler.running or
                 self.scheduler.swapped)
 
+    def reset_timer(self):
+        for controller in self.controllers:
+            controller.reset_timer()
+
+    def get_profile_results(self):
+        return [controller.get_profile_results() for controller in
+                self.controllers]
+
 
 def initialize_ray_cluster(
     address: str = 'auto',
diff --git a/cacheflow/parallel_utils/tensor_parallel/mappings.py b/cacheflow/parallel_utils/tensor_parallel/mappings.py
index d9ca3b460d7b..d6946e5c3b30 100644
--- a/cacheflow/parallel_utils/tensor_parallel/mappings.py
+++ b/cacheflow/parallel_utils/tensor_parallel/mappings.py
@@ -1,5 +1,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import time
+
 import torch
 
 from cacheflow.parallel_utils.parallel_state import (
@@ -7,6 +9,8 @@
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )
+from cacheflow.profile import (maybe_sync_for_profiling,
+                               add_to_communication_latency)
 
 from .utils import split_tensor_along_last_dim
 
@@ -17,9 +21,16 @@ def _reduce(input_):
     if get_tensor_model_parallel_world_size()==1:
         return input_
 
+    maybe_sync_for_profiling()
+    start_time = time.time()
+
     # All-reduce.
     torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
 
+    maybe_sync_for_profiling()
+    end_time = time.time()
+    add_to_communication_latency(end_time - start_time)
+
     return input_
 
 
@@ -78,8 +89,16 @@ def _gather_along_last_dim(input_):
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
+
+    maybe_sync_for_profiling()
+    start_time = time.time()
+
     torch.distributed.all_gather(tensor_list, input_,
                                  group=get_tensor_model_parallel_group())
+
+    maybe_sync_for_profiling()
+    end_time = time.time()
+    add_to_communication_latency(end_time - start_time)
+
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()
 
@@ -99,9 +118,17 @@
     output = torch.empty(dim_size, dtype=input_.dtype,
                          device=torch.cuda.current_device())
+
+    maybe_sync_for_profiling()
+    start_time = time.time()
+
     torch.distributed._all_gather_base(output, input_.contiguous(),
                                        group=get_tensor_model_parallel_group())
 
+    maybe_sync_for_profiling()
+    end_time = time.time()
+    add_to_communication_latency(end_time - start_time)
+
     return output
 
 
 def _reduce_scatter_along_first_dim(input_):
@@ -119,8 +146,17 @@
     output = torch.empty(dim_size, dtype=input_.dtype,
                          device=torch.cuda.current_device())
+
+    maybe_sync_for_profiling()
+    start_time = time.time()
+
     torch.distributed._reduce_scatter_base(output, input_.contiguous(),
                                            group=get_tensor_model_parallel_group())
+
+    maybe_sync_for_profiling()
+    end_time = time.time()
+    add_to_communication_latency(end_time - start_time)
+
     return output
diff --git a/cacheflow/profile.py b/cacheflow/profile.py
new file mode 100644
index 000000000000..f3f648a44935
--- /dev/null
+++ b/cacheflow/profile.py
@@ -0,0 +1,31 @@
+import torch
+
+# Global profile option
+
+SYNC_FOR_PROFILING = False
+
+def maybe_sync_for_profiling():
+    if SYNC_FOR_PROFILING:
+        torch.cuda.synchronize()
+
+def get_sync_for_profiling():
+    return SYNC_FOR_PROFILING
+
+def set_sync_for_profiling(new_value: bool = True):
+    global SYNC_FOR_PROFILING
+    SYNC_FOR_PROFILING = new_value
+
+# Communication latency
+
+COMMUNICATION_LATENCY = 0.0
+
+def reset_communication_latency():
+    global COMMUNICATION_LATENCY
+    COMMUNICATION_LATENCY = 0.0
+
+def add_to_communication_latency(latency):
+    global COMMUNICATION_LATENCY
+    COMMUNICATION_LATENCY += latency
+
+def get_communication_latency():
+    return COMMUNICATION_LATENCY
diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py
index bb357b132665..924036a15025 100644
--- a/cacheflow/worker/controller.py
+++ b/cacheflow/worker/controller.py
@@ -1,10 +1,11 @@
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Any
 
 import ray
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker
+from cacheflow.profile import get_sync_for_profiling
 
 
 DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id
@@ -57,6 +58,7 @@ def __init__(
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
                 model_path=model_path,
+                sync_for_profiling=get_sync_for_profiling(),
             )
             self.workers.append(worker)
 
@@ -95,3 +97,10 @@ def execute_stage(
         else:
             # TODO: Support pipeline parallelism.
             assert False
+
+    def get_profile_results(self) -> List[Dict[str, Any]]:
+        return ray.get([worker.get_profile_results.remote()
+                        for worker in self.workers])
+
+    def reset_timer(self) -> None:
+        ray.get([worker.reset_timer.remote() for worker in self.workers])
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index db0d46aabe9e..68a8022dc054 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -1,5 +1,6 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Any
 
+import time
 import torch
 
 from cacheflow.models import get_model
@@ -10,6 +11,10 @@
 from cacheflow.worker.cache_engine import CacheEngine
 from cacheflow.parallel_utils.parallel_state import (
     initialize_model_parallel, get_tensor_model_parallel_world_size)
+from cacheflow.profile import (maybe_sync_for_profiling,
+                               set_sync_for_profiling,
+                               reset_communication_latency,
+                               get_communication_latency)
 from cacheflow.utils import set_random_seed
 
 
@@ -29,7 +34,9 @@ def __init__(
         model_path: str,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
+        sync_for_profiling: bool = False,
     ) -> None:
+        set_sync_for_profiling(sync_for_profiling)
         self.init_distributed_environment(distributed_init_method,
                                           rank,
                                           world_size,
@@ -66,6 +73,8 @@ def __init__(
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
 
+        self.reset_timer()
+
     def init_distributed_environment(self,
                                      distributed_init_method: str,
@@ -211,6 +220,9 @@ def execute_stage(
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
+        maybe_sync_for_profiling()
+        start_time = time.time()
+
         # Issue cache operations.
         command_issued = False
         if blocks_to_swap_in:
@@ -247,8 +259,26 @@ def execute_stage(
             input_metadata=input_metadata,
             cache_events=cache_events,
         )
+        maybe_sync_for_profiling()
+        end_time = time.time()
+        latency = end_time - start_time
+        self.execution_latency += latency
+        self.num_profiled_steps += 1
+
         return output
 
+    def reset_timer(self) -> None:
+        self.execution_latency = 0.0
+        self.num_profiled_steps = 0
+        reset_communication_latency()
+
+    def get_profile_results(self) -> Dict[str, Any]:
+        communication_latency = get_communication_latency()
+        return {
+            'execution_latency': self.execution_latency,
+            'communication_latency': communication_latency,
+            'num_profiled_steps': self.num_profiled_steps,
+        }
 
 def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
     return x + [0] * ((-len(x)) % multiple_of)
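A minimal sketch of how the profiling hooks in this patch fit together, mirroring benchmark_latency.py above. The Server/Frontend construction is elided and assumed to be set up as in that script; the helper name run_profiled_batch is hypothetical and only illustrates the call order.

import time

from cacheflow.profile import set_sync_for_profiling

# Make maybe_sync_for_profiling() call torch.cuda.synchronize(), so the
# time.time() stamps around collectives (mappings.py) and execute_stage()
# (worker.py) measure finished GPU work. Workers pick this up via
# get_sync_for_profiling() when the Controller constructs them.
set_sync_for_profiling()

# ... construct server and frontend as in benchmark_latency.py ...

def run_profiled_batch(server, frontend, input_token_ids, sampling_params,
                       batch_size):
    # Illustrative helper, not part of the patch.
    # Zero each worker's execution/communication counters before timing.
    server.reset_timer()
    for _ in range(batch_size):
        frontend._add_query(input_token_ids, sampling_params)
    server.add_sequence_groups(frontend.get_inputs())

    start = time.time()
    while server.has_unfinished_requests():
        server.step()
    elapsed = time.time() - start

    # Results are indexed as [controller][worker]; each entry is the dict
    # returned by Worker.get_profile_results().
    worker_stats = server.get_profile_results()[0][0]
    return (elapsed,
            worker_stats["execution_latency"],
            worker_stats["communication_latency"])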