|
50 | 50 | from common.multimodal import MultimodalEmbeddingWrapper |
51 | 51 | from common.sampling import BaseSamplerRequest |
52 | 52 | from common.templating import PromptTemplate, find_prompt_template |
53 | | -from common.transformers_utils import GenerationConfig |
| 53 | +from common.transformers_utils import GenerationConfig, TokenizerConfig |
54 | 54 | from common.utils import calculate_rope_alpha, coalesce, unwrap |
55 | 55 | from endpoints.core.types.model import ModelCard, ModelCardParameters |
56 | 56 |
|
@@ -80,6 +80,7 @@ class ExllamaV2Container(BaseModelContainer): |
80 | 80 | draft_cache_mode: str = "FP16" |
81 | 81 | max_batch_size: Optional[int] = None |
82 | 82 | generation_config: Optional[GenerationConfig] = None |
| 83 | + tokenizer_config: Optional[TokenizerConfig] = None |
83 | 84 |
|
84 | 85 | # GPU split vars |
85 | 86 | gpu_split: List[float] = [] |
@@ -130,14 +131,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): |
130 | 131 | if generation_config_path.exists(): |
131 | 132 | try: |
132 | 133 | self.generation_config = await GenerationConfig.from_file( |
133 | | - generation_config_path.parent |
| 134 | + model_directory |
134 | 135 | ) |
135 | 136 | except Exception: |
136 | 137 | logger.error(traceback.format_exc()) |
137 | 138 | logger.warning( |
138 | 139 | "Skipping generation config load because of an unexpected error." |
139 | 140 | ) |
140 | 141 |
|
| 142 | + # Load tokenizer config overrides |
| 143 | + tokenizer_config_path = model_directory / "tokenizer_config.json" |
| 144 | + if tokenizer_config_path.exists(): |
| 145 | + try: |
| 146 | + self.tokenizer_config = await TokenizerConfig.from_file( |
| 147 | + model_directory |
| 148 | + ) |
| 149 | + except Exception: |
| 150 | + logger.error(traceback.format_exc()) |
| 151 | + logger.warning( |
| 152 | + "Skipping tokenizer config load because of an unexpected error." |
| 153 | + ) |
| 154 | + |
141 | 155 | # Set vision state and error if vision isn't supported on the current model |
142 | 156 | self.use_vision = unwrap(kwargs.get("vision"), False) |
143 | 157 | if self.use_vision and not self.config.vision_model_type: |
@@ -1240,9 +1254,17 @@ async def generate_gen( |
1240 | 1254 | ) and gen_settings.token_repetition_range == -1 |
1241 | 1255 |
|
1242 | 1256 | stop_conditions = params.stop |
1243 | | - add_bos_token = unwrap(params.add_bos_token, True) |
1244 | 1257 | ban_eos_token = params.ban_eos_token |
1245 | 1258 |
|
| 1259 | + |
| 1260 | + print(self.tokenizer_config.add_bos_token) |
| 1261 | + # Set add_bos_token for generation |
| 1262 | + add_bos_token = coalesce( |
| 1263 | + params.add_bos_token, self.tokenizer_config.add_bos_token, True |
| 1264 | + ) |
| 1265 | + |
| 1266 | + print(add_bos_token) |
| 1267 | + |
1246 | 1268 | # Fetch EOS tokens from generation_config if they exist |
1247 | 1269 | eos_tokens = ( |
1248 | 1270 | self.generation_config.eos_tokens() |
|
0 commit comments