```diff
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-import codecs
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
```
```diff
@@ -312,16 +311,21 @@ def _resolve_chat_template_content_format(
     tokenizer: AnyTokenizer,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-        tokenizer_chat_template = tokenizer.chat_template
+        try:
+            # Prioritize processor's chat template for multi-modal models
+            processor = cached_get_processor(tokenizer.name_or_path)
+            hf_chat_template = processor.chat_template
+        except Exception:
+            hf_chat_template = tokenizer.chat_template
     else:
-        tokenizer_chat_template = None
+        hf_chat_template = None
 
     jinja_text: Optional[str]
-    if isinstance(tokenizer_chat_template, str) and chat_template is None:
-        jinja_text = tokenizer_chat_template
-    elif (isinstance(tokenizer_chat_template, dict)
-          and chat_template in tokenizer_chat_template):
-        jinja_text = tokenizer_chat_template[chat_template]
+    if isinstance(hf_chat_template, str) and chat_template is None:
+        jinja_text = hf_chat_template
+    elif (isinstance(hf_chat_template, dict)
+          and chat_template in hf_chat_template):
+        jinja_text = hf_chat_template[chat_template]
     else:
         jinja_text = load_chat_template(chat_template, is_literal=True)
 
```
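The lookup order this hunk introduces is easiest to see standalone. Below is a minimal sketch, not vLLM code: `cached_get_processor` in the real file wraps `transformers.AutoProcessor`, so the sketch calls `AutoProcessor` directly, and `resolve_hf_chat_template` is a hypothetical name.

```python
from typing import Optional, Union

from transformers import AutoProcessor, AutoTokenizer


def resolve_hf_chat_template(model_id: str) -> Optional[Union[str, dict]]:
    """Hypothetical distillation of the processor-first lookup above."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    try:
        # Multi-modal checkpoints often store the chat template on the
        # processor rather than the tokenizer, so consult it first.
        return AutoProcessor.from_pretrained(model_id).chat_template
    except Exception:
        # Repos without a loadable processor fall back to the tokenizer.
        return tokenizer.chat_template
```

The resolved value may be either a single template string or a dict of named templates, which is why the hunk then branches on `isinstance(hf_chat_template, dict)` and indexes it with the `chat_template` argument.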
```diff
@@ -724,7 +728,7 @@ def load_chat_template(
             raise TypeError("chat_template is expected to be read directly "
                             "from its value")
 
-        return codecs.decode(chat_template, "unicode_escape")
+        return chat_template
 
     try:
         with open(chat_template) as f:
```
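Dropping the `codecs.decode(..., "unicode_escape")` round trip matters for literal templates. A small repro of what the removed call did (the motivation is my reading of the change, not stated in the diff):

```python
import codecs

# A literal backslash-n in the template body becomes a real newline:
print(codecs.decode(r"line1\nline2", "unicode_escape"))

# Non-ASCII template text fails outright, because the str input is first
# encoded as latin-1 before being unescape-decoded:
try:
    codecs.decode("{{ '你好' }}", "unicode_escape")
except UnicodeEncodeError as exc:
    print(exc)  # 'latin-1' codec can't encode characters ...
```

With the change, a literal template is returned exactly as given.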
```diff
@@ -1071,17 +1075,13 @@ def apply_hf_chat_template(
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    if chat_template is None:
-        chat_template = tokenizer.chat_template
-
-    # FIXME: Temporary workaround for
-    # https://huggingface.co/mistral-community/pixtral-12b/discussions/31
     if chat_template is None:
         try:
+            # Prioritize processor's chat template for multi-modal models
             processor = cached_get_processor(tokenizer.name_or_path)
             chat_template = processor.chat_template
         except Exception:
-            pass
+            chat_template = tokenizer.chat_template
 
     if chat_template is None:
         raise ValueError(
```
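With this hunk the dedicated FIXME for https://huggingface.co/mistral-community/pixtral-12b/discussions/31 disappears: the processor-first lookup becomes the single code path rather than a last-resort workaround. Note the priority inversion (my reading of the diff): previously `tokenizer.chat_template` won whenever it was set; now `processor.chat_template` does. Distilled into a hypothetical helper (`_resolve_template` is an illustrative name; `cached_get_processor` is the real helper from this file):

```python
def _resolve_template(chat_template, tokenizer):
    """Fallback chain: explicit argument -> processor -> tokenizer -> error."""
    if chat_template is None:
        try:
            # Processor template wins for multi-modal checkpoints.
            chat_template = cached_get_processor(
                tokenizer.name_or_path).chat_template
        except Exception:
            # No loadable processor: use the tokenizer's template.
            chat_template = tokenizer.chat_template
    if chat_template is None:
        raise ValueError("No chat template found for this model/tokenizer")
    return chat_template
```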