6 changes: 6 additions & 0 deletions .github/workflows/pr-test.yml
@@ -130,6 +130,12 @@ jobs:
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default

- name: Benchmark Offline Throughput (Non-streaming, small batch size)
timeout-minutes: 10
run: |
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size

performance-test-1-gpu-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
4 changes: 2 additions & 2 deletions python/sglang/bench_serving.py
@@ -845,13 +845,15 @@ def run_benchmark(args_: argparse.Namespace):
tokenizer = get_tokenizer(tokenizer_id)

if args.dataset_name == "sharegpt":
assert args.random_input_len is None and args.random_output_len is None
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "random":
assert args.random_input_len is not None and args.random_output_len is not None
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
@@ -964,13 +966,11 @@ def set_ulimit(target_soft_limit=65535):
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help="Number of input tokens per request, used only for random dataset.",
)
parser.add_argument(
"--random-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for random dataset.",
)
parser.add_argument(
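Together with the removal of the hard-coded defaults above, the two new asserts make the random-length flags an explicit part of the CLI contract: they are required for the random dataset and must be left unset for sharegpt. A minimal standalone sketch of that contract (illustration only; only the flag names come from the diff, the choices and default shown here are assumptions):

import argparse

# Sketch of the validation introduced above; not code from this PR.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", choices=["sharegpt", "random"], default="sharegpt")
parser.add_argument("--random-input-len", type=int)   # no default any more
parser.add_argument("--random-output-len", type=int)  # no default any more
args = parser.parse_args(
    ["--dataset-name", "random", "--random-input-len", "1024", "--random-output-len", "128"]
)

if args.dataset_name == "sharegpt":
    # sharegpt derives lengths from the dataset, so the random-* flags must stay unset
    assert args.random_input_len is None and args.random_output_len is None
elif args.dataset_name == "random":
    # random sampling needs both lengths passed explicitly
    assert args.random_input_len is not None and args.random_output_len is not None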
33 changes: 25 additions & 8 deletions python/sglang/srt/managers/scheduler.py
@@ -21,6 +21,7 @@
import os
import time
import warnings
from enum import IntEnum, auto
from typing import List, Optional, Union

import torch
@@ -72,6 +73,12 @@
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"


class NewBatchState(IntEnum):
FREE = auto()
FULL = auto()
NOREQ = auto()


class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""

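The NewBatchState enum added above replaces the old do_not_get_new_batch boolean with a three-way flag that also records why the scheduler should skip building a prefill batch. A small self-contained sketch of the transitions as I read them from the rest of this diff (not repository code):

from enum import IntEnum, auto

class NewBatchState(IntEnum):
    FREE = auto()   # worth calling get_new_prefill_batch()
    FULL = auto()   # batch/token/LoRA budget hit; cleared when a running request finishes
    NOREQ = auto()  # waiting queue is empty; cleared when a new request arrives

class SchedulerSketch:
    def __init__(self):
        self.state = NewBatchState.FREE
        self.current_inflight_req = None

    def on_new_request(self):
        # process_requests(): a fresh generate/embedding request only helps if we
        # were idle for lack of requests, not if the running batch was full.
        if self.state == NewBatchState.NOREQ:
            self.state = NewBatchState.FREE

    def on_request_finished(self):
        # handle_finished_requests(): a finished request frees capacity,
        # so a FULL scheduler is allowed to try again.
        if self.state == NewBatchState.FULL:
            self.state = NewBatchState.FREE

    def should_try_prefill(self):
        # forward_step(): skip get_new_prefill_batch() unless FREE,
        # except when a chunked (inflight) request still needs to be scheduled.
        return self.state == NewBatchState.FREE or self.current_inflight_req is not None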
@@ -224,7 +231,7 @@ def __init__(
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.do_not_get_new_batch = False
self.get_new_batch_state = NewBatchState.FREE

def event_loop(self):
while True:
@@ -263,12 +270,14 @@ def process_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
if self.get_new_batch_state == NewBatchState.NOREQ:
self.get_new_batch_state = NewBatchState.FREE
elif isinstance(
recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
if self.get_new_batch_state == NewBatchState.NOREQ:
self.get_new_batch_state = NewBatchState.FREE
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
@@ -281,11 +290,13 @@ def process_requests(self, recv_reqs: List):

@torch.inference_mode()
def forward_step(self):
if self.do_not_get_new_batch and self.current_inflight_req is None:
if (
self.get_new_batch_state != NewBatchState.FREE
and self.current_inflight_req is None
):
new_batch = None
else:
new_batch = self.get_new_prefill_batch()
self.do_not_get_new_batch = False

if new_batch is not None:
# Run a new prefill batch
@@ -449,6 +460,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
if running_bs >= self.max_running_requests:
self.get_new_batch_state = NewBatchState.FULL
return None

# Get priority queue
@@ -492,16 +504,19 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
)
> self.max_loras_per_batch
):
self.get_new_batch_state = NewBatchState.FULL
break

if adder.no_remaining_tokens():
self.get_new_batch_state = NewBatchState.FULL
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
or running_bs + len(adder.can_run_list) >= self.max_running_requests
):
self.get_new_batch_state = NewBatchState.FULL
break

can_run_list = adder.can_run_list
@@ -511,6 +526,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
self.current_inflight_req = adder.new_inflight_req

if len(can_run_list) == 0:
if len(self.waiting_queue) == 0:
self.get_new_batch_state = NewBatchState.NOREQ
return None

# Print stats
Expand Down Expand Up @@ -812,9 +829,6 @@ def forward_decode_batch(self, batch: ScheduleBatch):
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

if not has_finished:
self.do_not_get_new_batch = True

self.handle_finished_requests(batch)

def handle_finished_requests(self, batch: ScheduleBatch):
@@ -835,6 +849,9 @@ def handle_finished_requests(self, batch: ScheduleBatch):
for i, req in enumerate(batch.reqs):
if not req.finished() and req is not self.current_inflight_req:
unfinished_indices.append(i)
else:
if self.get_new_batch_state == NewBatchState.FULL:
self.get_new_batch_state = NewBatchState.FREE

if req.finished() or (
req.stream
19 changes: 14 additions & 5 deletions python/sglang/test/test_utils.py
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)


def run_bench_serving(model, num_prompts, request_rate, other_server_args):
def run_bench_serving(
model,
num_prompts,
request_rate,
other_server_args,
dataset_name="random",
random_input_len=4096,
random_output_len=2048,
disable_stream=False,
):
# Launch the server
base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
base_url=base_url,
host=None,
port=None,
dataset_name="random",
dataset_name=dataset_name,
dataset_path="",
model=None,
tokenizer=None,
num_prompts=num_prompts,
sharegpt_output_len=None,
random_input_len=4096,
random_output_len=2048,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
request_rate=request_rate,
multi=None,
seed=0,
output_file=None,
disable_tqdm=False,
disable_stream=False,
disable_stream=disable_stream,
disable_ignore_eos=False,
extra_request_body=None,
)
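The new keyword parameters default to the values that were previously hard-coded inside run_bench_serving, so existing callers keep their behavior. A hypothetical old-style call for comparison (the prompt count and server args here are assumptions, not taken from this PR):

from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, run_bench_serving

# Old-style call: dataset_name, random_input_len, random_output_len and disable_stream
# fall back to the previous behavior (random dataset, 4096/2048 tokens, streaming on).
res = run_bench_serving(
    model=DEFAULT_MODEL_NAME_FOR_TEST,
    num_prompts=200,            # assumed value for illustration
    request_rate=float("inf"),
    other_server_args=[],
)
print(res["output_throughput"])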
19 changes: 17 additions & 2 deletions test/srt/test_bench_serving.py
@@ -20,7 +20,22 @@ def test_offline_throughput_default(self):
)

if is_in_ci():
assert res["output_throughput"] > 2600
assert res["output_throughput"] > 2850

def test_offline_throughput_non_stream_small_batch_size(self):
res = run_bench_serving(
model=DEFAULT_MODEL_NAME_FOR_TEST,
num_prompts=50,
request_rate=float("inf"),
dataset_name="sharegpt",
random_input_len=None,
random_output_len=None,
disable_stream=True,
other_server_args=["--max-running-requests", "10"],
)

if is_in_ci():
assert res["output_throughput"] > 880

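For a local run of just this new case, the standard unittest loader can target it by the same dotted name the CI step uses; a sketch, assuming it is executed from test/srt with a GPU environment comparable to CI:

import unittest

# Load and run only the new non-streaming, small-batch-size benchmark test.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size"
)
unittest.TextTestRunner(verbosity=2).run(suite)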
def test_offline_throughput_without_radix_cache(self):
res = run_bench_serving(
Expand Down Expand Up @@ -58,7 +73,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
)

if is_in_ci():
assert res["output_throughput"] > 2600
assert res["output_throughput"] > 2960

def test_offline_throughput_default_fp8(self):
res = run_bench_serving(