@@ -2099,12 +2099,13 @@ def from_pretrained(
             template = template.removesuffix(".jinja")
             vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
 
+        remote_files = []
         if not is_local and not local_files_only:
             try:
                 remote_files = list_repo_files(pretrained_model_name_or_path)
             except Exception:
                 remote_files = []
-        else:
+        elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
             remote_files = os.listdir(pretrained_model_name_or_path)
 
         if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
@@ -2437,57 +2438,108 @@ def _from_pretrained(
         except NotImplementedError:
             vocab_size = 0
 
+        # Optionally patches mistral tokenizers with wrong regex
         if (
             vocab_size > 100000
             and hasattr(tokenizer, "_tokenizer")
             and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
         ):
-            from huggingface_hub import model_info
+            tokenizer = cls._patch_mistral_regex(
+                tokenizer,
+                pretrained_model_name_or_path,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
+                _is_local=_is_local,
+                init_kwargs=init_kwargs,
+                fix_mistral_regex=kwargs.get("fix_mistral_regex"),
+            )
 
-            def is_base_mistral(model_id: str) -> bool:
-                model = model_info(model_id)
-                if model.tags is not None:
-                    if re.search("base_model:.*mistralai", "".join(model.tags)):
-                        return True
-                return False
+        return tokenizer
 
-            if _is_local or is_base_mistral(pretrained_model_name_or_path):
-                _config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    "config.json",
-                    cache_dir=cache_dir,
-                    token=token,
-                    local_files_only=local_files_only,
-                    _raise_exceptions_for_missing_entries=False,
-                    _raise_exceptions_for_connection_errors=False,
-                    _commit_hash=_commit_hash,
-                )
-                if _config_file is not None:
-                    with open(_config_file, encoding="utf-8") as f:
-                        _config = json.load(f)
-                    transformers_version = _config.get("transformers_version")
+    @classmethod
+    def _patch_mistral_regex(
+        cls,
+        tokenizer,
+        pretrained_model_name_or_path,
+        token=None,
+        cache_dir=None,
+        local_files_only=False,
+        _commit_hash=None,
+        _is_local=False,
+        init_kwargs=None,
+        fix_mistral_regex=None,
+    ):
+        """
+        Patches mistral related tokenizers with incorrect regex if detected
+        1) Local file with an associated config saved next to it
+            >> Model type one of the mistral models (on older versions)
+        2) Remote models on the hub from official mistral models
+            >> Tags including `base_model:.*mistralai`
+        """
+        from huggingface_hub import model_info
 
-                    if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
-                        if _is_local and _config.model_type not in [
+        def is_base_mistral(model_id: str) -> bool:
+            model = model_info(model_id)
+            if model.tags is not None:
+                if re.search("base_model:.*mistralai", "".join(model.tags)):
+                    return True
+            return False
+
+        if _is_local or is_base_mistral(pretrained_model_name_or_path):
+            _config_file = cached_file(
+                pretrained_model_name_or_path,
+                "config.json",
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+                _commit_hash=_commit_hash,
+            )
+
+            # Detected using a (local) mistral tokenizer
+            mistral_config_detected = False
+            if _config_file is not None:
+                with open(_config_file, encoding="utf-8") as f:
+                    _config = json.load(f)
+                transformers_version = _config.get("transformers_version")
+                transformers_model_type = _config.get("model_type")
+
+                # Detect if we can skip the mistral fix by
+                # a) having a non-mistral tokenizer
+                # b) fixed version of transformers
+                if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+                    if (
+                        _is_local
+                        and transformers_model_type is not None
+                        and transformers_model_type
+                        not in [
                             "mistral",
                             "mistral3",
-                            "voxstral",
+                            "voxtral",
                             "ministral",
                             "pixtral",
-                        ]:
-                            return tokenizer
+                        ]
+                    ):
+                        return tokenizer
+                elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
+                    return tokenizer
 
+                mistral_config_detected = True
+
+            if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
                 # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
-                if "fix_mistral_regex" in init_kwargs:
+                if init_kwargs and "fix_mistral_regex" in init_kwargs:
                     setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])
 
-                fix_mistral_regex = kwargs.get("fix_mistral_regex")  # not init kwargs
                 # only warn if its not explicitly passed
                 if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
                     setattr(tokenizer, "fix_mistral_regex", False)
                     logger.warning(
                         f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
-                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
+                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
                         " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
                     )
                 elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
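A note on the version gate above: only configs saved by transformers >= 5.0.0 are trusted to already carry the corrected regex, while the <= 4.57.2 branch additionally lets local checkpoints with a non-mistral `model_type` opt out. A hypothetical standalone sketch of that predicate (the helper name and sample versions are made up for illustration):

```python
from packaging import version

def is_patch_candidate(saved_version):
    # Configs written by a fixed release are trusted as-is; the model_type
    # opt-out for <= 4.57.2 checkpoints is omitted here for brevity.
    if saved_version and version.parse(saved_version) >= version.parse("5.0.0"):
        return False
    return True

for saved in ("4.57.2", "4.58.0", "5.0.0", None):
    print(saved, is_patch_candidate(saved))  # True, True, False, True
```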
@@ -2500,7 +2552,6 @@ def is_base_mistral(model_id: str) -> bool:
                         ),
                         behavior="isolated",
                     )
-
         return tokenizer
 
     @staticmethod
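For reference, the user-facing remedy named in the warning is the new loading flag. A minimal usage sketch (assumes network access and a checkpoint that trips the detection; the model id is the one cited in the warning message):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    fix_mistral_regex=True,  # opt in to the corrected pre-tokenizer regex
)
```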