Skip to content

Commit 26e6a6f

Browse files
committed
fix default parser for llama & logprobs
1 parent 4ae9dfa commit 26e6a6f

8 files changed

Lines changed: 361 additions & 144 deletions

File tree

scratchpad/managers/detokenizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def __init__(
5050
# Init inter-process communication
5151
context = zmq.Context(2)
5252
self.recv_from_scheduler = get_zmq_socket(
53-
context, zmq.PULL, server_args.detokenizer_ipc_name
53+
context, zmq.PULL, server_args.detokenizer_ipc_name, True
5454
)
5555
self.send_to_tokenizer = get_zmq_socket(
56-
context, zmq.PUSH, server_args.tokenizer_ipc_name
56+
context, zmq.PUSH, server_args.tokenizer_ipc_name, False
5757
)
5858

5959
if server_args.skip_tokenizer_init:
@@ -228,6 +228,6 @@ def run_detokenizer_process(
228228
manager = DetokenizerManager(server_args)
229229
manager.event_loop()
230230
except Exception:
231-
msg = get_exception_traceback()
231+
msg: str = get_exception_traceback()
232232
logger.error(msg)
233233
kill_parent_process()

scratchpad/managers/tokenizer.py

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,35 @@
5151
class ReqState:
5252
"""Store the state of a request."""
5353

54-
out_list: List
54+
out_list: List[Dict[Any, Any]]
5555
finished: bool
5656
event: asyncio.Event
57-
obj: Any
57+
obj: Union[GenerateReqInput, EmbeddingReqInput]
5858

5959
# For metrics
6060
created_time: float
61-
first_token_time: Optional[float] = None
61+
finished_time: float = 0.0
62+
first_token_time: float = 0.0
63+
last_time: float = 0.0
64+
last_completion_tokens: int = 1
6265

6366
# For streaming output
6467
last_output_offset: int = 0
68+
# For incremental state update.
69+
text: str = ""
70+
output_ids: List[int] = dataclasses.field(default_factory=list)
71+
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
72+
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
73+
output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
74+
output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
75+
input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
76+
input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
77+
output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
78+
output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
79+
input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
80+
input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
81+
output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
82+
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
6583

6684

6785
class TokenizerManager:
@@ -77,10 +95,16 @@ def __init__(
7795
# Init inter-process communication
7896
context = zmq.asyncio.Context(2)
7997
self.recv_from_detokenizer = get_zmq_socket(
80-
context, zmq.PULL, server_args.tokenizer_ipc_name
98+
context,
99+
zmq.PULL,
100+
server_args.tokenizer_ipc_name,
101+
True,
81102
)
82103
self.send_to_scheduler = get_zmq_socket(
83-
context, zmq.PUSH, server_args.scheduler_input_ipc_name
104+
context,
105+
zmq.PUSH,
106+
server_args.scheduler_input_ipc_name,
107+
True,
84108
)
85109

86110
# Read model args
@@ -503,6 +527,7 @@ async def handle_loop(self):
503527
recv_obj: Union[
504528
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
505529
] = await self.recv_from_detokenizer.recv_pyobj()
530+
506531
if isinstance(recv_obj, UpdateWeightReqOutput):
507532
if self.server_args.dp_size == 1:
508533
self.model_update_result.set_result(recv_obj)
@@ -512,6 +537,7 @@ async def handle_loop(self):
512537
if len(self.model_update_tmp) == self.server_args.dp_size:
513538
self.model_update_result.set_result(self.model_update_tmp)
514539
continue
540+
515541
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
516542
if self.server_args.dp_size == 1:
517543
self.mem_pool_size.set_result(recv_obj)
@@ -525,6 +551,7 @@ async def handle_loop(self):
525551
assert isinstance(
526552
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
527553
), f"Unexpected obj received: {type(recv_obj)}"
554+
528555
for i, rid in enumerate(recv_obj.rids):
529556
state = self.rid_to_state.get(rid, None)
530557
if state is None:
@@ -537,21 +564,21 @@ async def handle_loop(self):
537564
if getattr(state.obj, "return_logprob", False):
538565
self.convert_logprob_style(
539566
meta_info,
567+
state,
540568
state.obj.top_logprobs_num,
541569
state.obj.token_ids_logprob,
542-
state.obj.return_text_in_logprobs,
570+
state.obj.return_text_in_logprobs
571+
and not self.server_args.skip_tokenizer_init,
543572
recv_obj,
544573
i,
545574
)
546-
547575
if not isinstance(recv_obj, BatchEmbeddingOut):
548576
meta_info.update(
549577
{
550578
"completion_tokens": recv_obj.completion_tokens[i],
551579
"cached_tokens": recv_obj.cached_tokens[i],
552580
}
553581
)
554-
555582
if isinstance(recv_obj, BatchStrOut):
556583
out_dict = {
557584
"text": recv_obj.output_strs[i],
@@ -569,80 +596,132 @@ async def handle_loop(self):
569596
"embedding": recv_obj.embeddings[i],
570597
"meta_info": meta_info,
571598
}
599+
572600
state.out_list.append(out_dict)
573601
state.finished = recv_obj.finished_reasons[i] is not None
574602
state.event.set()
575603

576604
def convert_logprob_style(
577605
self,
578606
meta_info: dict,
607+
state: ReqState,
579608
top_logprobs_num: int,
580609
token_ids_logprob: List[int],
581610
return_text_in_logprobs: bool,
582611
recv_obj: BatchStrOut,
583612
recv_obj_index: int,
584-
):
613+
) -> None:
614+
if len(recv_obj.input_token_logprobs_val) > 0:
615+
state.input_token_logprobs_val.extend(
616+
recv_obj.input_token_logprobs_val[recv_obj_index]
617+
)
618+
state.input_token_logprobs_idx.extend(
619+
recv_obj.input_token_logprobs_idx[recv_obj_index]
620+
)
621+
state.output_token_logprobs_val.extend(
622+
recv_obj.output_token_logprobs_val[recv_obj_index]
623+
)
624+
state.output_token_logprobs_idx.extend(
625+
recv_obj.output_token_logprobs_idx[recv_obj_index]
626+
)
585627
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
586-
recv_obj.input_token_logprobs_val[recv_obj_index],
587-
recv_obj.input_token_logprobs_idx[recv_obj_index],
628+
state.input_token_logprobs_val,
629+
state.input_token_logprobs_idx,
588630
return_text_in_logprobs,
589631
)
590632
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
591-
recv_obj.output_token_logprobs_val[recv_obj_index],
592-
recv_obj.output_token_logprobs_idx[recv_obj_index],
633+
state.output_token_logprobs_val,
634+
state.output_token_logprobs_idx,
593635
return_text_in_logprobs,
594636
)
595637

596638
if top_logprobs_num > 0:
639+
if len(recv_obj.input_top_logprobs_val) > 0:
640+
state.input_top_logprobs_val.extend(
641+
recv_obj.input_top_logprobs_val[recv_obj_index]
642+
)
643+
state.input_top_logprobs_idx.extend(
644+
recv_obj.input_top_logprobs_idx[recv_obj_index]
645+
)
646+
state.output_top_logprobs_val.extend(
647+
recv_obj.output_top_logprobs_val[recv_obj_index]
648+
)
649+
state.output_top_logprobs_idx.extend(
650+
recv_obj.output_top_logprobs_idx[recv_obj_index]
651+
)
597652
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
598-
recv_obj.input_top_logprobs_val[recv_obj_index],
599-
recv_obj.input_top_logprobs_idx[recv_obj_index],
653+
state.input_top_logprobs_val,
654+
state.input_top_logprobs_idx,
600655
return_text_in_logprobs,
601656
)
602657
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
603-
recv_obj.output_top_logprobs_val[recv_obj_index],
604-
recv_obj.output_top_logprobs_idx[recv_obj_index],
658+
state.output_top_logprobs_val,
659+
state.output_top_logprobs_idx,
605660
return_text_in_logprobs,
606661
)
607662

608663
if token_ids_logprob is not None:
664+
if len(recv_obj.input_token_ids_logprobs_val) > 0:
665+
state.input_token_ids_logprobs_val.extend(
666+
recv_obj.input_token_ids_logprobs_val[recv_obj_index]
667+
)
668+
state.input_token_ids_logprobs_idx.extend(
669+
recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
670+
)
671+
state.output_token_ids_logprobs_val.extend(
672+
recv_obj.output_token_ids_logprobs_val[recv_obj_index]
673+
)
674+
state.output_token_ids_logprobs_idx.extend(
675+
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
676+
)
609677
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
610-
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
611-
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
678+
state.input_token_ids_logprobs_val,
679+
state.input_token_ids_logprobs_idx,
612680
return_text_in_logprobs,
613681
)
614682
meta_info[
615683
"output_token_ids_logprobs"
616684
] = self.detokenize_top_logprobs_tokens(
617-
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
618-
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
685+
state.output_token_ids_logprobs_val,
686+
state.output_token_ids_logprobs_idx,
619687
return_text_in_logprobs,
620688
)
621689

622690
def detokenize_logprob_tokens(
623-
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
691+
self,
692+
token_logprobs_val: List[float],
693+
token_logprobs_idx: List[int],
694+
decode_to_text: bool,
624695
):
625-
# TODO(lianmin): This should run on DetokenizerManager
626696
if not decode_to_text:
627-
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
628-
629-
assert self.tokenizer is not None
630-
token_ids = [tid for _, tid in token_logprobs]
631-
token_texts = self.tokenizer.batch_decode(token_ids)
632-
return [
633-
(logprob, token_id, token_text)
634-
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
635-
]
697+
return [
698+
(logprob, token_id, None)
699+
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
700+
]
701+
else:
702+
assert self.tokenizer is not None
703+
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
704+
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
636705

637-
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
706+
def detokenize_top_logprobs_tokens(
707+
self,
708+
token_logprobs_val: List[float],
709+
token_logprobs_idx: List[int],
710+
decode_to_text: bool,
711+
):
638712
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
639713
# We should batch all top-k tokens in all positions.
640-
for i, token_top_logprobs in enumerate(top_logprobs):
641-
if token_top_logprobs:
642-
top_logprobs[i] = self.detokenize_logprob_tokens(
643-
token_top_logprobs, decode_to_text
714+
ret = []
715+
for i in range(len(token_logprobs_val)):
716+
if token_logprobs_val[i]:
717+
ret.append(
718+
self.detokenize_logprob_tokens(
719+
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
720+
)
644721
)
645-
return top_logprobs
722+
else:
723+
ret.append(None)
724+
return ret
646725

647726

648727
class SignalHandler:

scratchpad/nn/layers/logits_processor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,32 @@
2020

2121
@dataclass
2222
class LogitsProcessorOutput:
23-
## Part 1: This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor
23+
## Part 1: This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor
2424
# The logits of the next tokens. shape: [#seq, vocab_size]
2525
next_token_logits: torch.Tensor
2626
# Used by speculative decoding (EAGLE)
2727
# The last hidden layers
2828
hidden_states: Optional[torch.Tensor] = None
2929

30-
## Part 2: This part will be assigned in nn/layers/sampler.py::Sampler
30+
## Part 2: This part will be assigned in nn/layers/sampler.py::Sampler
3131
# The logprobs of the next tokens. shape: [#seq]
3232
next_token_logprobs: Optional[torch.Tensor] = None
3333
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
3434
next_token_top_logprobs_val: Optional[List] = None
3535
next_token_top_logprobs_idx: Optional[List] = None
36+
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
37+
next_token_token_ids_logprobs_val: Optional[List] = None
38+
next_token_token_ids_logprobs_idx: Optional[List] = None
3639

37-
## Part 3: Prefill-only. This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor
38-
# The normalized logprobs of prompts. shape: [#seq]
39-
normalized_prompt_logprobs: torch.Tensor = None
40+
## Part 3: Prefill-only. This part will be assigned in nn/layers/logits_processor.py::LogitsProcessor
4041
# The logprobs of input tokens. shape: [#token]
41-
input_token_logprobs: torch.Tensor = None
42+
input_token_logprobs: Optional[torch.Tensor] = None
4243
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
4344
input_top_logprobs_val: List = None
4445
input_top_logprobs_idx: List = None
46+
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
47+
input_token_ids_logprobs_val: Optional[List] = None
48+
input_token_ids_logprobs_idx: Optional[List] = None
4549

4650

4751
@dataclass

0 commit comments

Comments
 (0)