Merged
6 changes: 6 additions & 0 deletions .github/workflows/pr-test.yml
@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
 
+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
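The new step invokes a single unittest case, so it can be reproduced outside CI. A minimal sketch of an equivalent local run, assuming an sglang checkout and a GPU host comparable to the 1-gpu-runner:

import subprocess

# Hypothetical local reproduction of the CI step above: run the new benchmark
# test from test/srt exactly as the workflow does.
subprocess.run(
    [
        "python3",
        "-m",
        "unittest",
        "test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size",
    ],
    cwd="test/srt",
    check=True,  # fail loudly, like the CI job, if the benchmark test fails
)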
4 changes: 2 additions & 2 deletions python/sglang/bench_serving.py
@@ -845,13 +845,15 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ def set_ulimit(target_soft_limit=65535):
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
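With the defaults removed, --random-input-len and --random-output-len must now be passed explicitly for the random dataset and left unset for sharegpt; the new asserts enforce exactly that. A minimal sketch of the same rule as a standalone check, assuming an argparse.Namespace with these attributes (check_dataset_args is a hypothetical helper, not part of this PR):

import argparse

def check_dataset_args(args: argparse.Namespace) -> None:
    # ShareGPT samples carry their own lengths, so the random-* knobs must stay unset.
    if args.dataset_name == "sharegpt":
        if args.random_input_len is not None or args.random_output_len is not None:
            raise ValueError(
                "--random-input-len/--random-output-len apply only to --dataset-name random"
            )
    # The random dataset no longer has defaults, so both knobs are mandatory.
    elif args.dataset_name == "random":
        if args.random_input_len is None or args.random_output_len is None:
            raise ValueError(
                "--dataset-name random requires --random-input-len and --random-output-len"
            )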
18 changes: 10 additions & 8 deletions python/sglang/srt/managers/scheduler.py
@@ -222,7 +222,7 @@ def __init__(
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False
 
     def event_loop(self):
         while True:
@@ -261,12 +261,10 @@ def process_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -279,11 +277,12 @@
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-            self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -447,6 +446,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None
 
         # Get priority queue
@@ -490,16 +490,19 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break
 
             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break
 
         can_run_list = adder.can_run_list
@@ -810,9 +813,6 @@ def forward_decode_batch(self, batch: ScheduleBatch):
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
@@ -833,6 +833,8 @@ def handle_finished_requests(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False
 
             if req.finished() or (
                 req.stream
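Taken together, these edits replace the per-step do_not_get_new_batch hint with a latched batch_is_full flag: every admission limit (max running requests, LoRA slots per batch, remaining token budget) sets it, and only a finished request clears it. A condensed sketch of that state machine; names follow the diff, the surrounding machinery is simplified, and hit_capacity_limit is a hypothetical stand-in:

class SchedulerSketch:
    """Illustration only: the batch_is_full protocol, not the real Scheduler."""

    def __init__(self):
        self.batch_is_full = False
        self.waiting_queue = []
        self.current_inflight_req = None

    def hit_capacity_limit(self) -> bool:
        # Stand-in for the real checks: running_bs >= max_running_requests,
        # the max_loras_per_batch limit, and adder.no_remaining_tokens().
        return False

    def get_new_prefill_batch(self):
        # Any capacity limit latches the flag so later steps skip the scan.
        if self.hit_capacity_limit():
            self.batch_is_full = True
            return None
        if not self.waiting_queue:
            return None
        return self.waiting_queue.pop(0)

    def forward_step(self):
        # Fast path: skip the admission scan when it cannot succeed, i.e. the
        # batch is known to be full or nothing is waiting.
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.current_inflight_req is None:
            return None
        return self.get_new_prefill_batch()

    def on_request_finished(self, req):
        # A finished request may have freed capacity, so re-enable the scan.
        self.batch_is_full = False

The old hint was recomputed and reset every loop iteration, so decode-heavy phases still paid the admission-scan cost; latching the flag until capacity can actually change is what the new non-streaming small-batch test below measures.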
19 changes: 14 additions & 5 deletions python/sglang/test/test_utils.py
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(
@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
-        random_output_len=2048,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
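The new keyword arguments default to the previous hard-coded values, so existing callers keep the old behavior. Hypothetical call sites under that assumption (the second mirrors the test added in test_bench_serving.py below; the random_* values must be None for sharegpt to satisfy the new asserts in bench_serving.py):

from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, run_bench_serving

# Old behavior, unchanged: random dataset, 4096/2048 token lengths, streaming on.
res = run_bench_serving(
    model=DEFAULT_MODEL_NAME_FOR_TEST,
    num_prompts=300,
    request_rate=float("inf"),
    other_server_args=[],
)

# New path: ShareGPT prompts with streaming disabled and a small running batch.
res = run_bench_serving(
    model=DEFAULT_MODEL_NAME_FOR_TEST,
    num_prompts=200,
    request_rate=float("inf"),
    other_server_args=["--max-running-requests", "10"],
    dataset_name="sharegpt",
    random_input_len=None,
    random_output_len=None,
    disable_stream=True,
)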
21 changes: 18 additions & 3 deletions test/srt/test_bench_serving.py
@@ -20,7 +20,22 @@ def test_offline_throughput_default(self):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
+
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000
 
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -31,7 +46,7 @@ def test_offline_throughput_without_radix_cache(self):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -58,7 +73,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
        )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(