diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 22060243f97..67551f09c0f 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
 
+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 2f68b39bbc3..4afccf73a4f 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
@@ -852,6 +853,7 @@ def run_benchmark(args_: argparse.Namespace):
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ def set_ulimit(target_soft_limit=65535):
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
        type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5577c7fa408..9b3fc5cefab 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -222,7 +222,7 @@ def __init__(
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False
 
     def event_loop(self):
         while True:
@@ -261,12 +261,10 @@ def process_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -279,11 +277,12 @@ def process_requests(self, recv_reqs: List):
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-            self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -447,6 +446,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None
 
         # Get priority queue
@@ -490,9 +490,11 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break
 
             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
@@ -500,6 +502,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break
 
         can_run_list = adder.can_run_list
@@ -810,9 +813,6 @@ def forward_decode_batch(self, batch: ScheduleBatch):
                 if req.top_logprobs_num > 0:
                     req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
@@ -833,6 +833,8 @@ def handle_finished_requests(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False
 
             if req.finished() or (
                 req.stream
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 2c22f8d9017..7d844c9bc6e 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(
@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
-        random_output_len=2048,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index 2a327f858ea..056483487ba 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -20,7 +20,22 @@ def test_offline_throughput_default(self):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
+
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000
 
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -31,7 +46,7 @@ def test_offline_throughput_without_radix_cache(self):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -58,7 +73,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(