diff --git a/pyproject.toml b/pyproject.toml
index d438c83d6..4cadd3aa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ triton = [
 ]
 
 huggingface = [
-    "unsloth_zoo>=2025.5.5",
+    "unsloth_zoo>=2025.5.6",
     "packaging",
     "tyro",
     "transformers==4.51.3,!=4.47.0",
@@ -381,7 +381,7 @@ colab-ampere-torch220 = [
     "flash-attn>=2.6.3",
 ]
 colab-new = [
-    "unsloth_zoo>=2025.5.5",
+    "unsloth_zoo>=2025.5.6",
     "packaging",
     "tyro",
     "transformers==4.51.3,!=4.47.0",
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 5c8cc87a5..cfb3ece47 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -1036,9 +1036,21 @@
 {%- endif %}
 {%- endif %}
 {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{%- for message in messages[::-1] %}
+{%- for forward_message in messages %}
     {%- set index = (messages|length - 1) - loop.index0 %}
-    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+    {%- set message = messages[index] %}
+    {%- set current_content = message.content if message.content is not none else '' %}
+    {%- set tool_start = '<tool_response>' %}
+    {%- set tool_start_length = tool_start|length %}
+    {%- set start_of_message = current_content[:tool_start_length] %}
+    {%- set tool_end = '</tool_response>' %}
+    {%- set tool_end_length = tool_end|length %}
+    {%- set start_pos = (current_content|length) - tool_end_length %}
+    {%- if start_pos < 0 %}
+        {%- set start_pos = 0 %}
+    {%- endif %}
+    {%- set end_of_message = current_content[start_pos:] %}
+    {%- if ns.multi_step_tool and message.role == "user" and not(start_of_message == tool_start and end_of_message == tool_end) %}
         {%- set ns.multi_step_tool = false %}
         {%- set ns.last_query_index = index %}
     {%- endif %}
@@ -1053,8 +1065,9 @@
             {%- set reasoning_content = message.reasoning_content %}
         {%- else %}
             {%- if '</think>' in message.content %}
-                {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
-                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content = (message.content.split('</think>')|last).lstrip('\n') %}
+                {%- set reasoning_content = (message.content.split('</think>')|first).rstrip('\n') %}
+                {%- set reasoning_content = (reasoning_content.split('<think>')|last).lstrip('\n') %}
             {%- endif %}
         {%- endif %}
         {%- if loop.index0 > ns.last_query_index %}
@@ -1110,7 +1123,7 @@
 qwen3_ollama = \
 '''
 FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .Messages }}
+TEMPLATE """{{- if .Messages }}
 {{- if or .System .Tools }}<|im_start|>system
 {{- if .System }}
 {{ .System }}
 {{- end }}
@@ -1161,8 +1174,12 @@
 {{ end }}<|im_start|>assistant
 {{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
 PARAMETER stop "<|im_end|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
+PARAMETER stop "<|im_start|>"
+PARAMETER temperature 0.6
+PARAMETER min_p 0.0
+PARAMETER top_k 20
+PARAMETER top_p 0.95
+PARAMETER repeat_penalty 1
 '''
 
 qwen3_template_eos_token = "<|im_end|>"
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 882de28cb..118f4f053 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2025.5.3"
+__version__ = "2025.5.4"
 
 __all__ = [
     "SUPPORTS_BFLOAT16",
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index a1dbc8253..a233b26a8 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -541,10 +541,12 @@ def from_pretrained(
            if transformers_version < Version("4.50.0.dev0"):
                raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY)
        elif "csm-1b" in model_name.lower():
-            os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
-            os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
+            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails
+            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"
        elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
            raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "whisper" in model_name.lower():
+            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails
        pass
 
        if USE_MODELSCOPE and not os.path.exists(model_name):
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index d723fc4bd..4bbd8295c 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -817,6 +817,26 @@
        "microsoft/Phi-4-mini-reasoning",
        "unsloth/phi-4-mini-reasoning-bnb-4bit",
    ),
+    "unsloth/csm-1b" : (
+        "unsloth/csm-1b",
+        "sesame/csm-1b",
+    ),
+    "unsloth/whisper-large-v3" : (
+        "unsloth/whisper-large-v3",
+        "openai/whisper-large-v3",
+    ),
+    "unsloth/whisper-large-v3-turbo" : (
+        "unsloth/whisper-large-v3-turbo",
+        "openai/whisper-large-v3-turbo",
+    ),
+    "unsloth/whisper-small" : (
+        "unsloth/whisper-small",
+        "openai/whisper-small",
+    ),
+    "unsloth/CrisperWhisper" : (
+        "unsloth/CrisperWhisper",
+        "nyrahealth/CrisperWhisper",
+    ),
 }
 
 INT_TO_FLOAT_MAPPER = {}
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index cadfed943..2ba7c1391 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -188,7 +188,10 @@ def unsloth_base_fast_generate(
    # Use hybrid if sliding window seen, otherwise try static
    cache_implementation = getattr(self.config, "cache_implementation", None)
    if getattr(self, "_supports_static_cache", True):
-        cache_implementation = "static"
+        if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") == "0":
+            cache_implementation = "static"
+        else:
+            cache_implementation = None
    else:
        cache_implementation = None
    if cache_implementation is not None:
@@ -199,10 +202,10 @@ def unsloth_base_fast_generate(
            cache_implementation = "hybrid"
    if "generation_config" in kwargs:
        kwargs["generation_config"].cache_implementation = cache_implementation
-        kwargs["generation_config"].compile_config = _compile_config
+        kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None
    else:
        kwargs["cache_implementation"] = cache_implementation
-        kwargs["compile_config"] = _compile_config
+        kwargs["compile_config"] = _compile_config if cache_implementation is not None else None
    pass
 
    try:
@@ -310,6 +313,19 @@ def from_pretrained(
            bnb_compute_dtype = torch.float16
            do_forced_float32 = True
        pass
+
+        # Check for custom data-types
+        custom_datatype = None
+        correct_dtype = None
+        if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "":
+            custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"]
+            assert custom_datatype.count(";") == 1
+            bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1)
+            dtype = torch.float32
+            bnb_compute_dtype = eval(bnb_compute_dtype)
+            correct_dtype = bnb_compute_dtype
+        pass
+
        # Stop SDPA for some archs like Pixtral / Mistral3
        if not ("attn_implementation" in kwargs):
            kwargs["attn_implementation"] = "sdpa"
@@ -374,12 +390,18 @@ def from_pretrained(
 
        # Return old flag
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+        # Edit data-types
+        if custom_datatype is not None:
+            for name, module in model.named_modules():
+                exec(custom_datatype)
+        pass
+
        # Counteract saved tokenizers
        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
        is_vlm = (auto_model is AutoModelForVision2Seq)
        is_whisper = (whisper_language is not None and whisper_task is not None)
        auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
-        if whisper_language and whisper_task:
+        if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
            tokenizer = auto_processor.from_pretrained(
                tokenizer_name,
                padding_side = "right",
@@ -415,6 +437,7 @@ def from_pretrained(
            downcast_rope = False,
            fix_embeddings = False,
            do_forced_float32 = do_forced_float32,
+            correct_dtype = correct_dtype,
        )
        model, tokenizer = patch_tokenizer(model, tokenizer)
        model = post_patch_loss_function(model)
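
A minimal standalone sketch of how the new `UNSLOTH_FORCE_CUSTOM_DTYPE` escape hatch is interpreted by the `vision.py` changes above: the variable holds two `;`-separated parts, an expression eval'd into the bnb compute dtype, and a statement exec'd once per `(name, module)` pair after loading. The `SmallModel` class here is an illustrative stand-in, not part of the patch:

```python
import os
import torch
import torch.nn as nn

# Same value loader.py sets for csm-1b.
os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
    "torch.float16;"
    "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): "
    "module.to(torch.float16)"
)

class SmallModel(nn.Module):  # hypothetical stand-in for csm-1b
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)   # matches the '_proj' suffix
        self.norm   = nn.LayerNorm(8)   # matches nothing, stays fp32

model = SmallModel().to(torch.float32)

custom_datatype = os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "")
if custom_datatype != "":
    assert custom_datatype.count(";") == 1
    bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1)
    bnb_compute_dtype = eval(bnb_compute_dtype)  # -> torch.float16
    # Mirrors the named_modules loop added to from_pretrained:
    for name, module in model.named_modules():
        exec(custom_datatype)  # statement sees `name` and `module`

print(model.q_proj.weight.dtype)  # torch.float16
print(model.norm.weight.dtype)    # torch.float32
```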
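
And a small Python model of the `chat_templates.py` rewrite, useful for checking that the forward loop with a hand-computed reverse index visits the same messages as the old `messages[::-1]` form, and that the manual slice comparison matches the old `startswith`/`endswith` test (the sample `messages` list is made up for illustration):

```python
messages = [
    {"role": "user", "content": "<tool_response>42</tool_response>"},
    {"role": "assistant", "content": "Done."},
    {"role": "user", "content": "Thanks!"},
]

multi_step_tool = True
last_query_index = len(messages) - 1

for i, _forward_message in enumerate(messages):
    index = (len(messages) - 1) - i  # reverse index, computed by hand
    message = messages[index]
    current_content = message["content"] or ""
    tool_start, tool_end = "<tool_response>", "</tool_response>"
    start_of_message = current_content[:len(tool_start)]
    start_pos = max(len(current_content) - len(tool_end), 0)
    end_of_message = current_content[start_pos:]
    is_tool_response = (start_of_message == tool_start
                        and end_of_message == tool_end)
    if multi_step_tool and message["role"] == "user" and not is_tool_response:
        multi_step_tool = False
        last_query_index = index

print(last_query_index)  # 2: the last real user query, not a tool response
```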