diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 3367cb752a69..a5bf4c096f4b 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -68,6 +68,7 @@
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
+    DeepEPMode,
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
@@ -275,6 +276,9 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
+            enable_deepep_moe=model_runner.server_args.enable_deepep_moe,
+            deepep_mode=DeepEPMode[model_runner.server_args.deepep_mode],
         )
@@ -339,6 +343,7 @@ def latency_test_run_once(
     log_decode_step,
     profile,
     profile_filename_prefix,
+    dp_size,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -353,7 +358,7 @@ def latency_test_run_once(
     measurement_results = {
         "run_name": run_name,
-        "batch_size": batch_size,
+        "batch_size": batch_size * dp_size,
         "input_len": input_len,
         "output_len": output_len,
     }
@@ -378,7 +383,7 @@ def latency_test_run_once(
     synchronize(device)
     prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
-    throughput = input_len * batch_size / prefill_latency
+    throughput = input_len * batch_size * dp_size / prefill_latency
     rank_print(
         f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
     )
@@ -394,16 +399,16 @@ def latency_test_run_once(
         synchronize(device)
         latency = time.perf_counter() - tic
         tot_latency += latency
-        throughput = batch_size / latency
+        throughput = batch_size * dp_size / latency
         decode_latencies.append(latency)
         if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
             rank_print(
-                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                f"Decode {i}. Batch size: {batch_size * dp_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )

     if profile:
         profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size * dp_size}_input{input_len}_output{output_len}.trace.json.gz"
         parent_dir = os.path.dirname(os.path.abspath(profile_filename))
         os.makedirs(parent_dir, exist_ok=True)
         profiler.export_chrome_trace(profile_filename)
@@ -412,14 +417,14 @@ def latency_test_run_once(
     # Record decode timing from 2nd output
     if output_len > 1:
         med_decode_latency = np.median(decode_latencies)
-        med_decode_throughput = batch_size / med_decode_latency
+        med_decode_throughput = batch_size * dp_size / med_decode_latency
         rank_print(
             f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
         )
         measurement_results["median_decode_latency"] = med_decode_latency
         measurement_results["median_decode_throughput"] = med_decode_throughput

-    throughput = (input_len + output_len) * batch_size / tot_latency
+    throughput = (input_len + output_len) * batch_size * dp_size / tot_latency
     rank_print(
         f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
     )
@@ -464,6 +469,7 @@ def latency_test(
         log_decode_step=0,
         profile=False,
         profile_filename_prefix="",  # not used
+        dp_size=1 if not server_args.enable_dp_attention else server_args.dp_size,
     )

     rank_print("Benchmark ...")
@@ -486,6 +492,7 @@ def latency_test(
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
+            1 if not server_args.enable_dp_attention else server_args.dp_size,
         )
         if ret is not None:
             result_list.append(ret)
@@ -502,6 +509,12 @@ def latency_test(

 def main(server_args, bench_args):
     server_args.cuda_graph_max_bs = max(bench_args.batch_size)
+    if server_args.enable_dp_attention:
+        sub_batch_size = []
+        for i in range(len(bench_args.batch_size)):
+            assert bench_args.batch_size[i] % server_args.dp_size == 0
+            sub_batch_size.append(bench_args.batch_size[i] // server_args.dp_size)
+        bench_args.batch_size = tuple(sub_batch_size)
     _set_envs_and_config(server_args)
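A note on the accounting in this patch: with `--enable-dp-attention`, `main()` splits each requested global batch size evenly across the `dp_size` ranks, each rank benchmarks its own `batch_size // dp_size` sub-batch, and every reported metric (prefill/decode throughput, the `Decode {i}. Batch size:` log line, the profile filename) is scaled back up by `dp_size` so the numbers still describe the global batch. A minimal sketch of that arithmetic, assuming DP attention is enabled (`split_batch_sizes` is a hypothetical helper for illustration; the patch does this inline over `bench_args.batch_size`):

```python
# Sketch of the per-rank batch splitting that main() performs when
# --enable-dp-attention is set. Values are illustrative.
def split_batch_sizes(batch_sizes, dp_size):
    for bs in batch_sizes:
        # Mirrors the assert added in main(): every global batch size
        # must divide evenly across the DP ranks.
        assert bs % dp_size == 0, f"batch size {bs} not divisible by dp_size={dp_size}"
    return tuple(bs // dp_size for bs in batch_sizes)

# Example: a global batch of 128 with dp_size=4 runs 32 requests per rank,
# and the benchmark then reports throughput for 32 * 4 = 128 requests.
print(split_batch_sizes((128, 256), dp_size=4))  # -> (32, 64)
```

To exercise the new path, an invocation along these lines should work (all of these flags already exist in `ServerArgs`/`BenchArgs`): `python -m sglang.bench_one_batch --model-path <model> --tp-size 8 --enable-dp-attention --dp-size 8 --batch-size 128 --input-len 512 --output-len 32`.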