-from typing import Union
+from typing import List, Tuple, Union
 
 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,64 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        # NOTE: kwargs is a dict, so use dict.get() (getattr never matches a key).
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
42+ """Detokenizes the new token in conjuction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can directly use `convert_tokens_to_string`.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        # NOTE: tokens are strings; compare against all_special_tokens, not all_special_ids.
+        if skip_special_tokens and token in tokenizer.all_special_tokens:
+            continue
+        if token in tokenizer.added_tokens_encoder:
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text
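
For reference, a minimal usage sketch of the updated `get_tokenizer` (not part of the commit; the model name is illustrative, and it assumes a LLaMA-family `model_type` is listed in `_MODEL_TYPES_WITH_SLOW_TOKENIZER`, as the in-code comment suggests):

```python
# Hypothetical usage; "huggyllama/llama-7b" is an illustrative model name.
tokenizer = get_tokenizer("huggyllama/llama-7b")
# The guard forces the slow tokenizer and logs an informational message.

# Explicitly requesting the fast tokenizer for such a model now fails fast:
try:
    get_tokenizer("huggyllama/llama-7b", use_fast=True)
except ValueError as e:
    print(e)  # "Cannot use the fast tokenizer for llama due to ..."
```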
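And a sketch of how a caller might drive `detokenize_incrementally` during generation (also not part of the commit; the tokenizer choice and prompt are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works

prev_output_tokens = []
output_text = ""
# Pretend these ids arrive one at a time from a sampling loop.
for new_token_id in tokenizer.encode("Hello incremental world"):
    new_token, output_text = detokenize_incrementally(
        tokenizer, prev_output_tokens, new_token_id,
        skip_special_tokens=True)
    # The function never mutates prev_output_tokens; the caller owns it.
    prev_output_tokens.append(new_token)

print(output_text)  # full detokenized text so far
```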