Skip to content

Commit a2486eb

Browse files
authored
Fix a bug with logprob streaming + chunked prefill (#2403)
1 parent 61dec54 commit a2486eb

3 files changed

Lines changed: 24 additions & 13 deletions

File tree

python/sglang/bench_serving.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
321321
},
322322
"stream": not args.disable_stream,
323323
"lora_path": request_func_input.lora_name,
324+
"return_logprob": args.return_logprob,
325+
"logprob_start_len": -1,
324326
**request_func_input.extra_request_body,
325327
}
326328
headers = {}
@@ -911,7 +913,7 @@ async def limited_request_func(request_func_input, pbar):
911913
prompt=test_prompt,
912914
api_url=api_url,
913915
prompt_len=test_prompt_len,
914-
output_len=test_output_len,
916+
output_len=min(test_output_len, 32),
915917
lora_name=lora_name,
916918
extra_request_body=extra_request_body,
917919
)
@@ -1413,6 +1415,11 @@ def set_ulimit(target_soft_limit=65535):
14131415
action="store_true",
14141416
help="Disable ignoring EOS.",
14151417
)
1418+
parser.add_argument(
1419+
"--return-logprob",
1420+
action="store_true",
1421+
help="Return logprob.",
1422+
)
14161423
parser.add_argument(
14171424
"--extra-request-body",
14181425
metavar='{"key1": "value1", "key2": "value2"}',

python/sglang/srt/managers/scheduler.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -440,16 +440,11 @@ def recv_requests(self):
440440
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
441441
recv_reqs = []
442442

443-
if self.last_batch is None:
444-
recv_req = self.recv_from_tokenizer.recv_pyobj()
445-
recv_reqs.append(recv_req)
446-
else:
447-
while True:
448-
try:
449-
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
450-
except zmq.ZMQError:
451-
break
452-
recv_reqs.append(recv_req)
443+
while True:
444+
try:
445+
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
446+
except zmq.ZMQError:
447+
break
453448
else:
454449
recv_reqs = None
455450

@@ -949,6 +944,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
949944
batch.next_batch_sampling_info.sampling_info_done.set()
950945

951946
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
947+
skip_stream_req = None
952948

953949
if self.is_generation:
954950
logits_output, next_token_ids, bid = result
@@ -1005,6 +1001,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10051001
else:
10061002
# The prefill of a request that is still being chunked is not finished yet
10071003
req.is_being_chunked -= 1
1004+
# There is only at most one request being currently chunked.
1005+
# Because this request does not finish prefill,
1006+
# we don't want to stream the request currently being chunked.
1007+
skip_stream_req = req
10081008

10091009
if batch.next_batch_sampling_info:
10101010
batch.next_batch_sampling_info.update_regex_vocab_mask()
@@ -1034,7 +1034,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10341034
# The prefill of a request that is still being chunked is not finished yet
10351035
req.is_being_chunked -= 1
10361036

1037-
self.stream_output(batch.reqs)
1037+
self.stream_output(batch.reqs, skip_stream_req)
10381038

10391039
def process_batch_result_decode(self, batch: ScheduleBatch, result):
10401040
logits_output, next_token_ids, bid = result
@@ -1179,7 +1179,7 @@ def add_logprob_return_values(
11791179

11801180
return num_input_logprobs
11811181

1182-
def stream_output(self, reqs: List[Req]):
1182+
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
11831183
"""Stream the output to detokenizer."""
11841184
output_rids = []
11851185
output_meta_info: List[dict] = []
@@ -1199,6 +1199,9 @@ def stream_output(self, reqs: List[Req]):
11991199
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
12001200

12011201
for req in reqs:
1202+
if req is skip_req:
1203+
continue
1204+
12021205
# TODO(lianmin): revisit this for overlap + retract + stream
12031206
if req.finished() or (
12041207
req.stream and (is_stream_iter or len(req.output_ids) == 1)

python/sglang/test/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,7 @@ def run_bench_serving(
568568
disable_tqdm=False,
569569
disable_stream=disable_stream,
570570
disable_ignore_eos=False,
571+
return_logprob=False,
571572
lora_name=None,
572573
extra_request_body=None,
573574
profile=None,

0 commit comments

Comments
 (0)