diff --git a/pyproject.toml b/pyproject.toml
index d438c83d6..4cadd3aa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ triton = [
]
huggingface = [
- "unsloth_zoo>=2025.5.5",
+ "unsloth_zoo>=2025.5.6",
"packaging",
"tyro",
"transformers==4.51.3,!=4.47.0",
@@ -381,7 +381,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
- "unsloth_zoo>=2025.5.5",
+ "unsloth_zoo>=2025.5.6",
"packaging",
"tyro",
"transformers==4.51.3,!=4.47.0",
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 5c8cc87a5..cfb3ece47 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -1036,9 +1036,21 @@
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{%- for message in messages[::-1] %}
+{%- for forward_message in messages %}
{%- set index = (messages|length - 1) - loop.index0 %}
- {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set message = messages[index] %}
+ {%- set current_content = message.content if message.content is not none else '' %}
+ {%- set tool_start = '<tool_response>' %}
+ {%- set tool_start_length = tool_start|length %}
+ {%- set start_of_message = current_content[:tool_start_length] %}
+ {%- set tool_end = '</tool_response>' %}
+ {%- set tool_end_length = tool_end|length %}
+ {%- set start_pos = (current_content|length) - tool_end_length %}
+ {%- if start_pos < 0 %}
+ {%- set start_pos = 0 %}
+ {%- endif %}
+ {%- set end_of_message = current_content[start_pos:] %}
+ {%- if ns.multi_step_tool and message.role == "user" and not(start_of_message == tool_start and end_of_message == tool_end) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
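
The rewrite above avoids `messages[::-1]` and `.startswith()`/`.endswith()`, which some Jinja-like runtimes (e.g. the template engines used for GGUF/llama.cpp deployments) do not implement; forward iteration plus index arithmetic and manual slicing are equivalent. A Python sketch of the same logic, with a toy `messages` list as a stand-in:

```python
# Illustrative only: what the rewritten Jinja loop computes.
messages = [
    {"role": "user", "content": "<tool_response>42</tool_response>"},
    {"role": "assistant", "content": "The answer is 42."},
]
tool_start, tool_end = "<tool_response>", "</tool_response>"

for i, _ in enumerate(messages):
    index = (len(messages) - 1) - i                        # walk from the last message backwards
    content = messages[index]["content"] or ""
    starts = content[: len(tool_start)] == tool_start      # startswith, by slicing
    start_pos = max(len(content) - len(tool_end), 0)       # clamp like the template
    ends = content[start_pos:] == tool_end                 # endswith, by slicing
    print(index, starts and ends)
```
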
@@ -1053,8 +1065,9 @@
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in message.content %}
- {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
- {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- set content = (message.content.split('</think>')|last).lstrip('\n') %}
+ {%- set reasoning_content = (message.content.split('</think>')|first).rstrip('\n') %}
+ {%- set reasoning_content = (reasoning_content.split('<think>')|last).lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
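
For reference, the same reasoning/content split in plain Python (illustrative only; in the template the `[-1]`/`[0]` indexing is replaced with the more widely supported `|last`/`|first` filters):

```python
# Everything after the final </think> is the visible reply; everything
# between the last <think> and the first </think> is the reasoning trace.
text = "<think>\nchain of thought\n</think>\nfinal answer"

if "</think>" in text:
    content = text.split("</think>")[-1].lstrip("\n")           # "final answer"
    reasoning_content = text.split("</think>")[0].rstrip("\n")  # up to first </think>
    reasoning_content = reasoning_content.split("<think>")[-1].lstrip("\n")
```
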
@@ -1110,7 +1123,7 @@
qwen3_ollama = \
'''
FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .Messages }}
+TEMPLATE """{{- if .Messages }}
{{- if or .System .Tools }}<|im_start|>system
{{- if .System }}
{{ .System }}
@@ -1161,8 +1174,12 @@
{{ end }}<|im_start|>assistant
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
PARAMETER stop "<|im_end|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
+PARAMETER stop "<|im_start|>"
+PARAMETER temperature 0.6
+PARAMETER min_p 0.0
+PARAMETER top_k 20
+PARAMETER top_p 0.95
+PARAMETER repeat_penalty 1
'''
qwen3_template_eos_token = "<|im_end|>"
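
The Modelfile now also stops on `<|im_start|>` and replaces the old sampling defaults with Qwen3's recommended settings. For users generating through transformers instead of Ollama, a hypothetical equivalent (names here are illustrative, not PR code):

```python
from transformers import GenerationConfig

# Mirror of the new Modelfile PARAMETER lines as a transformers config.
qwen3_sampling = GenerationConfig(
    do_sample = True,
    temperature = 0.6,
    top_p = 0.95,
    top_k = 20,
    min_p = 0.0,
    repetition_penalty = 1.0,
)
```
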
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 882de28cb..118f4f053 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "2025.5.3"
+__version__ = "2025.5.4"
__all__ = [
"SUPPORTS_BFLOAT16",
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index a1dbc8253..a233b26a8 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -541,10 +541,12 @@ def from_pretrained(
if transformers_version < Version("4.50.0.dev0"):
raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY)
elif "csm-1b" in model_name.lower():
- os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
- os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
+ os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails
+ os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"
elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
+ elif "whisper" in model_name.lower():
+ os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails
pass
if USE_MODELSCOPE and not os.path.exists(model_name):
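
The new `UNSLOTH_FORCE_CUSTOM_DTYPE` value packs two fields separated by the first `;`: an expression for the bnb compute dtype, and a per-module cast statement that vision.py later executes. A small sketch of that format (parsing mirrors the split in vision.py below):

```python
import torch

# The value set above, split into its two halves.
value = (
    "torch.float16;"
    "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): "
    "module.to(torch.float16)"
)
dtype_expr, cast_statement = value.split(";", 1)
bnb_compute_dtype = eval(dtype_expr)   # -> torch.float16
```
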
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index d723fc4bd..4bbd8295c 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -817,6 +817,26 @@
"microsoft/Phi-4-mini-reasoning",
"unsloth/phi-4-mini-reasoning-bnb-4bit",
),
+ "unsloth/csm-1b" : (
+ "unsloth/csm-1b",
+ "sesame/csm-1b",
+ ),
+ "unsloth/whisper-large-v3" : (
+ "unsloth/whisper-large-v3",
+ "openai/whisper-large-v3",
+ ),
+ "unsloth/whisper-large-v3-turbo" : (
+ "unsloth/whisper-large-v3-turbo",
+ "openai/whisper-large-v3-turbo",
+ ),
+ "unsloth/whisper-small" : (
+ "unsloth/whisper-small",
+ "openai/whisper-small",
+ ),
+ "unsloth/CrisperWhisper" : (
+ "unsloth/CrisperWhisper",
+ "nyrahealth/CrisperWhisper",
+ ),
}
INT_TO_FLOAT_MAPPER = {}
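
These entries map each unsloth mirror to its upstream repo id. A toy sketch (stand-in dict name, not the real mapper) of how such a table can be inverted so either id resolves to the same model:

```python
# Either the unsloth mirror or the upstream name resolves to the same key.
FLOAT_MAPPER = {
    "unsloth/whisper-large-v3": (
        "unsloth/whisper-large-v3",
        "openai/whisper-large-v3",
    ),
}
lookup = {alias: key for key, aliases in FLOAT_MAPPER.items() for alias in aliases}
assert lookup["openai/whisper-large-v3"] == "unsloth/whisper-large-v3"
```
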
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index cadfed943..2ba7c1391 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -188,7 +188,10 @@ def unsloth_base_fast_generate(
# Use hybrid if sliding window seen, otherwise try static
cache_implementation = getattr(self.config, "cache_implementation", None)
if getattr(self, "_supports_static_cache", True):
- cache_implementation = "static"
+ if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") == "0":
+ cache_implementation = "static"
+ else:
+ cache_implementation = None
else:
cache_implementation = None
if cache_implementation is not None:
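
Condensed, the branch above now consults the flag set in loader.py before requesting a static KV cache; an illustrative helper, not PR code:

```python
import os

# Static cache only when the model supports it and was not flagged.
def pick_cache_implementation(supports_static_cache: bool) -> str | None:
    if not supports_static_cache:
        return None
    if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") != "0":
        return None
    return "static"
```
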
@@ -199,10 +202,10 @@ def unsloth_base_fast_generate(
cache_implementation = "hybrid"
if "generation_config" in kwargs:
kwargs["generation_config"].cache_implementation = cache_implementation
- kwargs["generation_config"].compile_config = _compile_config
+ kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None
else:
kwargs["cache_implementation"] = cache_implementation
- kwargs["compile_config"] = _compile_config
+ kwargs["compile_config"] = _compile_config if cache_implementation is not None else None
pass
try:
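
Since compiled decoding only pays off with a fixed-shape cache, `compile_config` is now forwarded only when a cache implementation was actually chosen. The same wiring as a stand-alone sketch:

```python
# Illustrative mirror of the kwarg wiring in the hunk above.
def wire_generation_kwargs(kwargs, cache_implementation, compile_config):
    if "generation_config" in kwargs:
        kwargs["generation_config"].cache_implementation = cache_implementation
        kwargs["generation_config"].compile_config = (
            compile_config if cache_implementation is not None else None
        )
    else:
        kwargs["cache_implementation"] = cache_implementation
        kwargs["compile_config"] = (
            compile_config if cache_implementation is not None else None
        )
    return kwargs
```
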
@@ -310,6 +313,19 @@ def from_pretrained(
bnb_compute_dtype = torch.float16
do_forced_float32 = True
pass
+
+ # Check for custom data-types
+ custom_datatype = None
+ correct_dtype = None
+ if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "":
+ custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"]
+ assert custom_datatype.count(";") == 1
+ bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1)
+ dtype = torch.float32
+ bnb_compute_dtype = eval(bnb_compute_dtype)
+ correct_dtype = bnb_compute_dtype
+ pass
+
# Stop SDPA for some archs like Pixtral / Mistral3
if not ("attn_implementation" in kwargs):
kwargs["attn_implementation"] = "sdpa"
@@ -374,12 +390,18 @@ def from_pretrained(
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+ # Edit data-types
+ if custom_datatype is not None:
+ for name, module in model.named_modules():
+ exec(custom_datatype)
+ pass
+
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
is_vlm = (auto_model is AutoModelForVision2Seq)
is_whisper = (whisper_language is not None and whisper_task is not None)
auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
- if whisper_language and whisper_task:
+ if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
tokenizer = auto_processor.from_pretrained(
tokenizer_name,
padding_side = "right",
@@ -415,6 +437,7 @@ def from_pretrained(
downcast_rope = False,
fix_embeddings = False,
do_forced_float32 = do_forced_float32,
+ correct_dtype = correct_dtype,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
model = post_patch_loss_function(model)
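
Taken together with the new mapper entries, this should allow loading Whisper through the generic loader; a hypothetical usage sketch (return values assumed to follow FastModel's usual (model, tokenizer) pair, with the tokenizer being an AutoProcessor here):

```python
from unsloth import FastModel

# Whisper loads with an AutoProcessor and without static-cache generation.
model, processor = FastModel.from_pretrained(
    "unsloth/whisper-large-v3",
    load_in_4bit = False,
)
```
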