@@ -440,16 +440,11 @@ def recv_requests(self):
440440 if self .tp_rank == 0 or self .server_args .enable_dp_attention :
441441 recv_reqs = []
442442
443- if self .last_batch is None :
444- recv_req = self .recv_from_tokenizer .recv_pyobj ()
445- recv_reqs .append (recv_req )
446- else :
447- while True :
448- try :
449- recv_req = self .recv_from_tokenizer .recv_pyobj (zmq .NOBLOCK )
450- except zmq .ZMQError :
451- break
452- recv_reqs .append (recv_req )
443+ while True :
444+ try :
445+ recv_req = self .recv_from_tokenizer .recv_pyobj (zmq .NOBLOCK )
446+ except zmq .ZMQError :
447+ break
453448 else :
454449 recv_reqs = None
455450
@@ -949,6 +944,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
949944 batch .next_batch_sampling_info .sampling_info_done .set ()
950945
951946 def process_batch_result_prefill (self , batch : ScheduleBatch , result ):
947+ skip_stream_req = None
952948
953949 if self .is_generation :
954950 logits_output , next_token_ids , bid = result
@@ -1005,6 +1001,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10051001 else :
10061002 # being chunked reqs' prefill is not finished
10071003 req .is_being_chunked -= 1
1004+ # There is only at most one request being currently chunked.
1005+ # Because this request does not finish prefill,
1006+ # we don't want to stream the request currently being chunked.
1007+ skip_stream_req = req
10081008
10091009 if batch .next_batch_sampling_info :
10101010 batch .next_batch_sampling_info .update_regex_vocab_mask ()
@@ -1034,7 +1034,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10341034 # being chunked reqs' prefill is not finished
10351035 req .is_being_chunked -= 1
10361036
1037- self .stream_output (batch .reqs )
1037+ self .stream_output (batch .reqs , skip_stream_req )
10381038
10391039 def process_batch_result_decode (self , batch : ScheduleBatch , result ):
10401040 logits_output , next_token_ids , bid = result
@@ -1179,7 +1179,7 @@ def add_logprob_return_values(
11791179
11801180 return num_input_logprobs
11811181
1182- def stream_output (self , reqs : List [Req ]):
1182+ def stream_output (self , reqs : List [Req ], skip_req : Optional [ Req ] = None ):
11831183 """Stream the output to detokenizer."""
11841184 output_rids = []
11851185 output_meta_info : List [dict ] = []
@@ -1199,6 +1199,9 @@ def stream_output(self, reqs: List[Req]):
11991199 is_stream_iter = self .forward_ct_decode % self .stream_interval == 0
12001200
12011201 for req in reqs :
1202+ if req is skip_req :
1203+ continue
1204+
12021205 # TODO(lianmin): revisit this for overlap + retract + stream
12031206 if req .finished () or (
12041207 req .stream and (is_stream_iter or len (req .output_ids ) == 1 )
0 commit comments