
Commit e867178

Incrementally decode output tokens (#121)
1 parent aedba6d commit e867178

4 files changed: +83, -17 lines


cacheflow/core/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ def update(
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append_token(output.output_token, output.logprobs)
+                seq.append_token_id(output.output_token, output.logprobs)
         return self.running.copy()

     def free_seq(self, seq: Sequence) -> None:
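
For context, the renamed call takes the sampled token id plus its logprobs and records them on the sequence. A minimal sketch of that flow (SimpleSeq and the shape of step_output are illustrative assumptions, not cacheflow's actual classes):

from typing import Dict, List

class SimpleSeq:
    # Illustrative stand-in for a running sequence; not the cacheflow Sequence class.
    def __init__(self, seq_id: int) -> None:
        self.seq_id = seq_id
        self.output_token_ids: List[int] = []
        self.output_logprobs: List[Dict[int, float]] = []

    def append_token_id(self, token_id: int, logprobs: Dict[int, float]) -> None:
        assert token_id in logprobs
        self.output_token_ids.append(token_id)
        self.output_logprobs.append(logprobs)

running = [SimpleSeq(seq_id=0), SimpleSeq(seq_id=1)]
# Hypothetical sampler output for one step: seq_id -> (sampled token id, logprobs).
step_output = {0: (17, {17: -0.3}), 1: (42, {42: -1.1})}

for seq in running:
    token_id, logprobs = step_output[seq.seq_id]
    seq.append_token_id(token_id, logprobs)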

cacheflow/sequence.py

Lines changed: 8 additions & 3 deletions
@@ -24,7 +24,7 @@ def __init__(
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

-    def append_token(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob

@@ -64,6 +64,7 @@ def __init__(

         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
         self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
             last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

-    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()
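
To show how the new output_tokens field relates to the rest of the per-sequence state, here is a simplified sketch (an illustration only; the real Sequence also appends tokens to logical blocks, and get_last_token_id below is a plausible reading of the method used by the server, not its verbatim definition):

from typing import Dict, List

class SketchSequence:
    # Simplified per-sequence bookkeeping for incremental decoding.
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []        # sampled token ids
        self.output_logprobs: List[Dict[int, float]] = []
        self.output_tokens: List[str] = []           # decoded token strings (new in this commit)
        self.output_text = ""                        # running detokenized text
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprobs: Dict[int, float]) -> None:
        assert token_id in logprobs
        self.output_token_ids.append(token_id)
        self.output_logprobs.append(logprobs)
        self.cumulative_logprob += logprobs[token_id]

    def get_last_token_id(self) -> int:
        # Last generated token id, falling back to the last prompt token.
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]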

cacheflow/server/llm_server.py

Lines changed: 12 additions & 12 deletions
@@ -14,7 +14,8 @@
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -185,18 +186,17 @@ def step(self) -> List[RequestOutput]:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
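
The effect of this change is that each step converts only the newest token id to a token string and caches it on the sequence, instead of re-decoding every output token id of every running sequence. A rough stand-alone comparison, using the Hugging Face "gpt2" fast tokenizer purely for illustration (this is not cacheflow code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
generated_ids = tokenizer.encode("Incremental decoding avoids repeated work")

# Old approach: re-decode all output ids on every step.
for step in range(1, len(generated_ids) + 1):
    text = tokenizer.decode(generated_ids[:step])  # id-to-token work grows with step

# New approach: convert only the newest id to a token string and cache it,
# mirroring what detokenize_incrementally does on its fast path. (The final
# string assembly still walks the cached token strings.)
output_tokens, output_text = [], ""
for token_id in generated_ids:
    new_token = tokenizer.convert_ids_to_tokens(token_id)
    output_tokens.append(new_token)
    output_text = tokenizer.convert_tokens_to_string(output_tokens)

print(output_text)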

cacheflow/server/tokenizer_utils.py

Lines changed: 62 additions & 1 deletion
@@ -1,8 +1,12 @@
-from typing import Union
+from typing import List, Tuple, Union

 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,62 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        if getattr(kwargs, "use_fast", False) == True:
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
+    """Detokenizes the new token in conjuction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can directly use `convert_tokens_to_string`.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if token in tokenizer.added_tokens_encoder:
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text
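
A possible usage pattern for the new helper, driving it token by token the way _decode_sequences does (the "gpt2" tokenizer and the example string are assumptions for illustration; cacheflow must be importable for the second import to work):

from transformers import AutoTokenizer

from cacheflow.server.tokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer.encode("Hello, incremental world")

output_tokens = []
output_text = ""
for token_id in token_ids:
    # Each call decodes only the newest id against the cached token strings
    # and returns the updated full output text.
    new_token, output_text = detokenize_incrementally(
        tokenizer, output_tokens, token_id, skip_special_tokens=True)
    output_tokens.append(new_token)

print(output_text)  # Hello, incremental world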
