Skip to content

Commit 47cb2a0

Browse files
committed
Model: Add TokenizerConfig stub and add_eos_token fallback
This stub fetches the add_eos_token field from the HF tokenizer config. Ideally, this should be in the backend rather than tabby. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
1 parent aa657fa commit 47cb2a0

3 files changed

Lines changed: 46 additions & 3 deletions

File tree

backends/exllamav2/model.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from common.multimodal import MultimodalEmbeddingWrapper
5151
from common.sampling import BaseSamplerRequest
5252
from common.templating import PromptTemplate, find_prompt_template
53-
from common.transformers_utils import GenerationConfig
53+
from common.transformers_utils import GenerationConfig, TokenizerConfig
5454
from common.utils import calculate_rope_alpha, coalesce, unwrap
5555
from endpoints.core.types.model import ModelCard, ModelCardParameters
5656

@@ -80,6 +80,7 @@ class ExllamaV2Container(BaseModelContainer):
8080
draft_cache_mode: str = "FP16"
8181
max_batch_size: Optional[int] = None
8282
generation_config: Optional[GenerationConfig] = None
83+
tokenizer_config: Optional[TokenizerConfig] = None
8384

8485
# GPU split vars
8586
gpu_split: List[float] = []
@@ -130,14 +131,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
130131
if generation_config_path.exists():
131132
try:
132133
self.generation_config = await GenerationConfig.from_file(
133-
generation_config_path.parent
134+
model_directory
134135
)
135136
except Exception:
136137
logger.error(traceback.format_exc())
137138
logger.warning(
138139
"Skipping generation config load because of an unexpected error."
139140
)
140141

142+
# Load tokenizer config overrides
143+
tokenizer_config_path = model_directory / "tokenizer_config.json"
144+
if tokenizer_config_path.exists():
145+
try:
146+
self.tokenizer_config = await TokenizerConfig.from_file(
147+
model_directory
148+
)
149+
except Exception:
150+
logger.error(traceback.format_exc())
151+
logger.warning(
152+
"Skipping tokenizer config load because of an unexpected error."
153+
)
154+
141155
# Set vision state and error if vision isn't supported on the current model
142156
self.use_vision = unwrap(kwargs.get("vision"), False)
143157
if self.use_vision and not self.config.vision_model_type:
@@ -1240,9 +1254,17 @@ async def generate_gen(
12401254
) and gen_settings.token_repetition_range == -1
12411255

12421256
stop_conditions = params.stop
1243-
add_bos_token = unwrap(params.add_bos_token, True)
12441257
ban_eos_token = params.ban_eos_token
12451258

1259+
1260+
print(self.tokenizer_config.add_bos_token)
1261+
# Set add_bos_token for generation
1262+
add_bos_token = coalesce(
1263+
params.add_bos_token, self.tokenizer_config.add_bos_token, True
1264+
)
1265+
1266+
print(add_bos_token)
1267+
12461268
# Fetch EOS tokens from generation_config if they exist
12471269
eos_tokens = (
12481270
self.generation_config.eos_tokens()

common/templating.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ async def find_prompt_template(template_name, model_dir: pathlib.Path):
239239
]
240240

241241
# Add lookup from prompt template name if provided
242+
# TODO: Possibly link to the TokenizerConfig class
242243
if template_name:
243244
find_template_functions[:0] = [
244245
lambda: PromptTemplate.from_file(pathlib.Path("templates") / template_name),

common/transformers_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,23 @@ async def from_file(cls, model_directory: pathlib.Path):
5353
contents = await hf_config_json.read()
5454
hf_config_dict = json.loads(contents)
5555
return cls.model_validate(hf_config_dict)
56+
57+
58+
class TokenizerConfig(BaseModel):
59+
"""
60+
An abridged version of HuggingFace's tokenizer config.
61+
"""
62+
63+
add_bos_token: Optional[bool] = None
64+
65+
@classmethod
66+
async def from_file(cls, model_directory: pathlib.Path):
67+
"""Create an instance from a tokenizer config file."""
68+
69+
tokenizer_config_path = model_directory / "tokenizer_config.json"
70+
async with aiofiles.open(
71+
tokenizer_config_path, "r", encoding="utf8"
72+
) as tokenizer_config_json:
73+
contents = await tokenizer_config_json.read()
74+
tokenizer_config_dict = json.loads(contents)
75+
return cls.model_validate(tokenizer_config_dict)

0 commit comments

Comments
 (0)