
Commit b605555

vasqu and ArthurZucker authored
[Mistral Tokenizers] Fix tokenizer detection (#42389)
* fix
* sanity check
* style
* comments
* make it v5 explicit
* make explicit fixes possible in local tokenizers
* remove hub usage on local
* fix
* extend test for no config case
* move mistral patch outside to separate fn
* fix local path only
* add a test
* make sure test does not pass before this PR
* styling
* make sure it exists
* fix
* fix
* rename
* up
* last nit i hope lord

---------

Co-authored-by: Arthur <[email protected]>
1 parent 1ce12ac commit b605555
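For context, the user-facing effect of the change below is the `fix_mistral_regex` flag on `from_pretrained`. A minimal sketch of how it is meant to be used, assuming a Mistral checkpoint affected by the pre-tokenizer regex bug (the model id is taken from the discussion link quoted in the warning inside the diff; it is illustrative, not part of this commit):

from transformers import AutoTokenizer

# Without the flag, an affected Mistral tokenizer loads as before but logs a warning
# pointing at the incorrect pre-tokenizer regex.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")

# Passing fix_mistral_regex=True opts into the corrected regex at load time.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    fix_mistral_regex=True,
)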

File tree: 4 files changed, +149 −33 lines

  src/transformers/tokenization_utils_base.py
  tests/models/auto/test_tokenization_auto.py
  tests/models/llama/test_tokenization_llama.py
  tests/test_tokenization_common.py

src/transformers/tokenization_utils_base.py

Lines changed: 83 additions & 32 deletions
@@ -2099,12 +2099,13 @@ def from_pretrained(
                 template = template.removesuffix(".jinja")
                 vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"

+        remote_files = []
         if not is_local and not local_files_only:
             try:
                 remote_files = list_repo_files(pretrained_model_name_or_path)
             except Exception:
                 remote_files = []
-        else:
+        elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
             remote_files = os.listdir(pretrained_model_name_or_path)

         if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
@@ -2437,57 +2438,108 @@ def _from_pretrained(
         except NotImplementedError:
             vocab_size = 0

+        # Optionally patches mistral tokenizers with wrong regex
         if (
             vocab_size > 100000
             and hasattr(tokenizer, "_tokenizer")
             and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
         ):
-            from huggingface_hub import model_info
+            tokenizer = cls._patch_mistral_regex(
+                tokenizer,
+                pretrained_model_name_or_path,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
+                _is_local=_is_local,
+                init_kwargs=init_kwargs,
+                fix_mistral_regex=kwargs.get("fix_mistral_regex"),
+            )

-            def is_base_mistral(model_id: str) -> bool:
-                model = model_info(model_id)
-                if model.tags is not None:
-                    if re.search("base_model:.*mistralai", "".join(model.tags)):
-                        return True
-                return False
+        return tokenizer

-            if _is_local or is_base_mistral(pretrained_model_name_or_path):
-                _config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    "config.json",
-                    cache_dir=cache_dir,
-                    token=token,
-                    local_files_only=local_files_only,
-                    _raise_exceptions_for_missing_entries=False,
-                    _raise_exceptions_for_connection_errors=False,
-                    _commit_hash=_commit_hash,
-                )
-                if _config_file is not None:
-                    with open(_config_file, encoding="utf-8") as f:
-                        _config = json.load(f)
-                    transformers_version = _config.get("transformers_version")
+    @classmethod
+    def _patch_mistral_regex(
+        cls,
+        tokenizer,
+        pretrained_model_name_or_path,
+        token=None,
+        cache_dir=None,
+        local_files_only=False,
+        _commit_hash=None,
+        _is_local=False,
+        init_kwargs=None,
+        fix_mistral_regex=None,
+    ):
+        """
+        Patches mistral related tokenizers with incorrect regex if detected
+        1) Local file with an associated config saved next to it
+            >> Model type one of the mistral models (on older versions)
+        2) Remote models on the hub from official mistral models
+            >> Tags including `base_model:.*mistralai`
+        """
+        from huggingface_hub import model_info

-                    if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
-                        if _is_local and _config.model_type not in [
+        def is_base_mistral(model_id: str) -> bool:
+            model = model_info(model_id)
+            if model.tags is not None:
+                if re.search("base_model:.*mistralai", "".join(model.tags)):
+                    return True
+            return False
+
+        if _is_local or is_base_mistral(pretrained_model_name_or_path):
+            _config_file = cached_file(
+                pretrained_model_name_or_path,
+                "config.json",
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+                _commit_hash=_commit_hash,
+            )
+
+            # Detected using a (local) mistral tokenizer
+            mistral_config_detected = False
+            if _config_file is not None:
+                with open(_config_file, encoding="utf-8") as f:
+                    _config = json.load(f)
+                transformers_version = _config.get("transformers_version")
+                transformers_model_type = _config.get("model_type")
+
+                # Detect if we can skip the mistral fix by
+                # a) having a non-mistral tokenizer
+                # b) fixed version of transformers
+                if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+                    if (
+                        _is_local
+                        and transformers_model_type is not None
+                        and transformers_model_type
+                        not in [
                             "mistral",
                             "mistral3",
-                            "voxstral",
+                            "voxtral",
                             "ministral",
                             "pixtral",
-                        ]:
-                            return tokenizer
+                        ]
+                    ):
+                        return tokenizer
+                elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
+                    return tokenizer

+                mistral_config_detected = True
+
+            if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
                 # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
-                if "fix_mistral_regex" in init_kwargs:
+                if init_kwargs and "fix_mistral_regex" in init_kwargs:
                     setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])

-                fix_mistral_regex = kwargs.get("fix_mistral_regex")  # not init kwargs
                 # only warn if its not explicitly passed
                 if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
                     setattr(tokenizer, "fix_mistral_regex", False)
                     logger.warning(
                         f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
-                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
+                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
                         " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
                     )
                 elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
@@ -2500,7 +2552,6 @@ def is_base_mistral(model_id: str) -> bool:
                         ),
                         behavior="isolated",
                     )
-
         return tokenizer

     @staticmethod
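To summarize the detection flow added above: a tokenizer is only considered for the regex patch when it is loaded locally or from a hub repo tagged as a Mistral derivative, and a config.json saved by a fixed release (or naming a non-Mistral architecture) short-circuits the check. The following compact paraphrase is not part of the diff; it strips out the `cached_file`/`model_info` plumbing and uses a hypothetical helper name:

from typing import Optional

from packaging import version

# Model types the patch treats as Mistral-family (copied from the diff above).
MISTRAL_MODEL_TYPES = {"mistral", "mistral3", "voxtral", "ministral", "pixtral"}


def needs_mistral_regex_check(is_local: bool, config: Optional[dict], hub_tagged_mistral: bool) -> bool:
    # Only local tokenizers and hub repos tagged `base_model:*mistralai` are inspected at all.
    if not (is_local or hub_tagged_mistral):
        return False

    config_detected = False
    if config is not None:
        saved_version = config.get("transformers_version")
        model_type = config.get("model_type")
        if saved_version and version.parse(saved_version) <= version.parse("4.57.2"):
            # A local config naming a non-Mistral architecture rules the patch out.
            if is_local and model_type is not None and model_type not in MISTRAL_MODEL_TYPES:
                return False
        elif saved_version and version.parse(saved_version) >= version.parse("5.0.0"):
            # Saved by a release that already ships the corrected regex.
            return False
        config_detected = True

    return config_detected or (not is_local and hub_tagged_mistral)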

tests/models/auto/test_tokenization_auto.py

Lines changed: 39 additions & 0 deletions
@@ -34,9 +34,13 @@
     GPT2Tokenizer,
     GPT2TokenizerFast,
     PreTrainedTokenizerFast,
+    Qwen2Tokenizer,
+    Qwen2TokenizerFast,
+    Qwen3MoeConfig,
     RobertaTokenizer,
     RobertaTokenizerFast,
     is_tokenizers_available,
+    logging,
 )
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
 from transformers.models.auto.tokenization_auto import (
@@ -49,6 +53,7 @@
     DUMMY_DIFF_TOKENIZER_IDENTIFIER,
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    CaptureLogger,
     RequestCounter,
     is_flaky,
     require_tokenizers,
@@ -229,6 +234,40 @@ def test_auto_tokenizer_from_local_folder(self):
         self.assertIsInstance(tokenizer2, tokenizer.__class__)
         self.assertEqual(tokenizer2.vocab_size, 12)

+    def test_auto_tokenizer_from_local_folder_mistral_detection(self):
+        """See #42374 for reference, ensuring proper mistral detection on local tokenizers"""
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
+        config = Qwen3MoeConfig.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
+        self.assertIsInstance(tokenizer, (Qwen2Tokenizer, Qwen2TokenizerFast))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tokenizer.save_pretrained(tmp_dir)
+
+            # Case 1: Tokenizer with no config associated
+            logger = logging.get_logger("transformers.tokenization_utils_base")
+            with CaptureLogger(logger) as cl:
+                AutoTokenizer.from_pretrained(tmp_dir)
+            self.assertNotIn(
+                "with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e",
+                cl.out,
+            )
+
+            # Case 2: Tokenizer with config associated
+            # Needed to be saved along the tokenizer to detect (non)mistral
+            # for a version where the regex bug occurs
+            config_dict = config.to_diff_dict()
+            config_dict["transformers_version"] = "4.57.2"
+
+            # Manually saving to avoid versioning clashes
+            config_path = os.path.join(tmp_dir, "config.json")
+            with open(config_path, "w", encoding="utf-8") as f:
+                json.dump(config_dict, f, indent=2, sort_keys=True)
+
+            tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)
+
+            self.assertIsInstance(tokenizer2, tokenizer.__class__)
+            self.assertTrue(tokenizer2.vocab_size > 100_000)
+
     def test_auto_tokenizer_fast_no_slow(self):
         tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
         # There is no fast CTRL so this always gives us a slow tokenizer.

tests/models/llama/test_tokenization_llama.py

Lines changed: 5 additions & 1 deletion
@@ -49,7 +49,11 @@
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
-    from_pretrained_id = ["hf-internal-testing/llama-tokenizer", "meta-llama/Llama-2-7b-hf"]
+    from_pretrained_id = [
+        "hf-internal-testing/llama-tokenizer",
+        "meta-llama/Llama-2-7b-hf",
+        "meta-llama/Meta-Llama-3-8B",
+    ]
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast

tests/test_tokenization_common.py

Lines changed: 22 additions & 0 deletions
@@ -4670,3 +4670,25 @@ def test_empty_input_string(self):
         for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
             output = tokenizer(empty_input_string, return_tensors=return_type)
             self.assertEqual(output.input_ids.dtype, target_type)
+
+    def test_local_files_only(self):
+        from transformers import AutoTokenizer
+
+        pretrained_list = getattr(self, "from_pretrained_id", []) or []
+        for pretrained_name in pretrained_list:
+            with self.subTest(f"AutoTokenizer ({pretrained_name})"):
+                # First cache the tokenizer files
+                try:
+                    tokenizer_cached = AutoTokenizer.from_pretrained(pretrained_name)
+
+                    # Now load with local_files_only=True
+                    tokenizer_local = AutoTokenizer.from_pretrained(pretrained_name, local_files_only=True)
+
+                    # Check that the two tokenizers are identical
+                    self.assertEqual(tokenizer_cached.get_vocab(), tokenizer_local.get_vocab())
+                    self.assertEqual(
+                        tokenizer_cached.all_special_tokens_extended,
+                        tokenizer_local.all_special_tokens_extended,
+                    )
+                except Exception as _:
+                    pass  # if the pretrained model is not loadable how could it pass locally :)
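The new common test above boils down to the following round trip. This is a minimal standalone sketch, not part of the diff; the checkpoint id is one of the `from_pretrained_id` entries used in the Llama test and is only an example:

from transformers import AutoTokenizer

model_id = "hf-internal-testing/llama-tokenizer"  # example id, reused from the test above

# The first load populates the local cache from the hub.
tok_cached = AutoTokenizer.from_pretrained(model_id)

# The second load must resolve every file from the cache without contacting the hub.
tok_local = AutoTokenizer.from_pretrained(model_id, local_files_only=True)

assert tok_cached.get_vocab() == tok_local.get_vocab()
assert tok_cached.all_special_tokens_extended == tok_local.all_special_tokens_extended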
