diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 5785894a2..2c2e36182 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -20,6 +20,7 @@
     "to_sharegpt",
     "standardize_sharegpt",
+    "standardize_data_formats",
     "apply_chat_template",
     "train_on_responses_only",
@@ -37,7 +38,9 @@
 import re
 from unsloth_zoo.dataset_utils import (
     train_on_responses_only,
+    standardize_data_formats,
 )
+standardize_sharegpt = standardize_data_formats

 CHAT_TEMPLATES = {}
 DEFAULT_SYSTEM_MESSAGE = {}
@@ -934,6 +937,84 @@
 pass

+# =========================================== Gemma-3
+# Obtained via
+# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
+gemma3_template = \
+"""{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+    {%- if messages[0]['content'] is string -%}
+        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
+    {%- else -%}
+        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
+    {%- endif -%}
+    {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+    {%- set first_user_prefix = "" -%}
+    {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+    {%- endif -%}
+    {%- if (message['role'] == 'assistant') -%}
+        {%- set role = "model" -%}
+    {%- else -%}
+        {%- set role = message['role'] -%}
+    {%- endif -%}
+    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
+    {%- if message['content'] is string -%}
+        {{ message['content'] | trim }}
+    {%- elif message['content'] is iterable -%}
+        {%- for item in message['content'] -%}
+            {%- if item['type'] == 'image' -%}
+                {{ '<start_of_image>' }}
+            {%- elif item['type'] == 'text' -%}
+                {{ item['text'] | trim }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{ raise_exception("Invalid content type") }}
+    {%- endif -%}
+    {{ '<end_of_turn>\n' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{ '<start_of_turn>model\n' }}
+{%- endif -%}
+"""
+
+# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
+gemma3_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
+{{ .Content }}<end_of_turn>
+{{ if $last }}<start_of_turn>model
+{{ end }}
+{{- else if eq .Role "assistant" }}<start_of_turn>model
+{{ .Content }}{{ if not $last }}<end_of_turn>
+{{ end }}
+{{- end }}
+{{- end }}"""
+PARAMETER stop "<end_of_turn>"
+PARAMETER stop "<eos>"
+PARAMETER temperature 0.1
+PARAMETER min_p 0.0
+PARAMETER top_k 64
+PARAMETER top_p 0.95
+PARAMETER num_predict 32768
+'''
+
+gemma3_template_eos_token = "<end_of_turn>"
+CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
+DEFAULT_SYSTEM_MESSAGE["gemma-3"] = None # No system message in Gemma-3
+
+CHAT_TEMPLATES["gemma3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
+DEFAULT_SYSTEM_MESSAGE["gemma3"] = None # No system message in Gemma-3
+pass

 def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
     system_message_pattern = r"\{system_message\}"
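Illustrative note, not part of the patch: the hunk above registers the new template under both "gemma-3" and "gemma3". A minimal sketch of how it is picked up once the patch is applied (it assumes `tokenizer` was already loaded from a Gemma-3 checkpoint):

    from unsloth.chat_templates import get_chat_template

    # Assumes `tokenizer` came from an Unsloth / Hugging Face Gemma-3 checkpoint.
    tokenizer = get_chat_template(tokenizer, chat_template = "gemma-3")
    prompt = tokenizer.apply_chat_template(
        [{"role" : "user", "content" : "Hello!"}],
        tokenize = False,
        add_generation_prompt = True,
    )
    # prompt renders roughly to:
    # <bos><start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n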
@@ -1033,11 +1114,12 @@ def get_chat_template(

     # Check fast tokenizer
     if not is_fast_tokenizer:
-        print(
-            "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
-            "Please log a Github issue if you want this as a new feature!\n"\
-            "Your chat template will still work, but it won't add or edit tokens."
-        )
+        pass
+        # print(
+        #     "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
+        #     "Please log a Github issue if you want this as a new feature!\n"\
+        #     "Your chat template will still work, but it won't add or edit tokens."
+        # )

     elif token_mapping is not None:
         # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
@@ -1396,82 +1478,6 @@ def __convert_to_sharegpt__(examples):
 pass


-def standardize_sharegpt(
-    dataset,
-    aliases_for_system = ["system",],
-    aliases_for_user = ["user", "human", "input",],
-    aliases_for_assistant = ["gpt", "assistant", "output",],
-):
-    """
-    Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
-
-    Get aliases for the system, user and assistant roles.
-    These shall map to "system", "user" and "assistant" respectively.
-
-    aliases_for_system = ["system",],
-    aliases_for_user = ["user", "human", "input",],
-    aliases_for_assistant = ["gpt", "assistant", "output",],
-    """
-    import collections
-    import itertools
-
-    convos = dataset[:10]["conversations"]
-    uniques = collections.defaultdict(list)
-    for convo in convos:
-        for message in convo:
-            for key, value in message.items():
-                uniques[key].append(value)
-    pass
-
-    # Must be only 2 entries
-    assert(len(uniques.keys()) == 2)
-
-    keys = list(uniques.keys())
-    length_first = len(set(uniques[keys[0]]))
-    length_second = len(set(uniques[keys[1]]))
-
-    if length_first < length_second:
-        # Role is assigned to the first element
-        role_key = keys[0]
-        content_key = keys[1]
-    else:
-        role_key = keys[1]
-        content_key = keys[0]
-    pass
-
-    # Check roles are in aliases
-    all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
-    roles = set(uniques[role_key])
-    leftover_aliases = (all_aliases | roles) - all_aliases
-    if len(leftover_aliases) != 0:
-        raise TypeError(
-            f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
-        )
-    pass
-
-    # Mapping for aliases
-    aliases_mapping = {}
-    for x in aliases_for_system: aliases_mapping[x] = "system"
-    for x in aliases_for_user: aliases_mapping[x] = "user"
-    for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
-
-    def _standardize_dataset(examples):
-        convos = examples["conversations"]
-        all_convos = []
-        for convo in convos:
-            new_convo = [
-                { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
-                for message in convo
-            ]
-            all_convos.append(new_convo)
-        pass
-        return { "conversations" : all_convos, }
-    pass
-
-    return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
-pass
-
-
 def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
     added_tokens_decoder = tokenizer.added_tokens_decoder.values()
     added_tokens_decoder = [str(x) for x in added_tokens_decoder]
@@ -1934,6 +1940,11 @@ def formatting_prompts_func(examples):
     tokenizer._ollama_modelfile = modelfile
     tokenizer._unsloth_input_part = input_part
     tokenizer._unsloth_output_part = output_part
+    if hasattr(tokenizer, "tokenizer"):
+        tokenizer.tokenizer.chat_template = jinja_template
+        tokenizer.tokenizer._ollama_modelfile = modelfile
+        tokenizer.tokenizer._unsloth_input_part = input_part
+        tokenizer.tokenizer._unsloth_output_part = output_part

     return dataset.map(formatting_prompts_func, batched = True,)
 pass
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 77bfa8762..a3fc12f6d 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -71,6 +71,7 @@
 from platform import system as platform_system
 platform_system = platform_system()
 import numpy as np
+import contextlib
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version
@@ -113,6 +114,11 @@
 from unsloth_zoo.training_utils import (
     prepare_model_for_training,
 )
+from unsloth_zoo.temporary_patches import (
+    TEMPORARY_PATCHES,
+)
+for temporary_patch in TEMPORARY_PATCHES:
+    temporary_patch()

 # =============================================
 # Disable some warnings which can get annoying
@@ -981,7 +987,14 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
                 "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
             )
         pass
-    return self._old_compute_loss(model, inputs, *args, **kwargs)
+
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
+        autocaster = contextlib.nullcontext()
+    else:
+        autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32)
+    with autocaster:
+        outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
+    return outputs
 pass
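For readability, the `_unsloth_pre_compute_loss` change above boils down to the following pattern (a condensed sketch, not the actual trainer code; `compute_loss` stands in for the wrapped method):

    import os, contextlib, torch

    def call_with_optional_float32(compute_loss, *args, **kwargs):
        # UNSLOTH_FORCE_FLOAT32 is set to "1" by the model loader (for Gemma-3 when the
        # compute dtype would otherwise be float16); otherwise the loss runs unchanged.
        if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
            autocaster = contextlib.nullcontext()
        else:
            autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32)
        with autocaster:
            return compute_loss(*args, **kwargs)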
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 7ae6e92d1..700073985 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -38,6 +38,7 @@
 from ..tokenizer_utils import *
 if HAS_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
+from .vision import FastBaseModel

 # Final patching code
 from transformers.models.llama.modeling_llama import (
@@ -1648,6 +1649,7 @@ def from_pretrained(
         disable_log_stats = False,
         **kwargs,
     ):
+        os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
         if trust_remote_code:
             if fast_inference:
                 raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.")
@@ -2016,6 +2018,31 @@ def get_peft_model(
         temporary_location = "_unsloth_temporary_saved_buffers",
         **kwargs,
     ):
+        if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
+            return FastBaseModel.get_peft_model(
+                model = model,
+                r = r,
+                target_modules = target_modules,
+                lora_alpha = lora_alpha,
+                lora_dropout = lora_dropout,
+                bias = bias,
+                finetune_vision_layers = False,
+                finetune_language_layers = True,
+                finetune_attention_modules = True,
+                finetune_mlp_modules = True,
+                layers_to_transform = layers_to_transform,
+                layers_pattern = layers_pattern,
+                use_gradient_checkpointing = use_gradient_checkpointing,
+                random_state = random_state,
+                max_seq_length = max_seq_length,
+                use_rslora = use_rslora,
+                modules_to_save = modules_to_save,
+                init_lora_weights = init_lora_weights,
+                loftq_config = loftq_config,
+                temporary_location = temporary_location,
+                **kwargs,
+            )
+        pass
         if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
             print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
             return model
@@ -2435,6 +2462,12 @@ def patch_peft_model(
         model,
         use_gradient_checkpointing = True,
     ):
+        if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
+            return FastBaseModel.patch_peft_model(
+                model = model,
+                use_gradient_checkpointing = use_gradient_checkpointing,
+            )
+        pass
         if not isinstance(model, PeftModelForCausalLM):
             raise TypeError(
                 "Unsloth: Your model needs to call `.get_peft_model` first!"
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 92a166f69..1b54c8c7f 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -70,7 +70,7 @@ class FastLanguageModel(FastLlamaModel):
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length = None,
+        max_seq_length = 2048,
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -96,7 +96,7 @@ def from_pretrained(
         if load_in_8bit or full_finetuning:
             return FastModel.from_pretrained(
                 model_name = model_name,
-                max_seq_length = max_seq_length, # [TODO] No effect
+                max_seq_length = max_seq_length,
                 dtype = dtype,
                 load_in_4bit = load_in_4bit,
                 load_in_8bit = load_in_8bit,
@@ -295,7 +295,7 @@ def from_pretrained(
         else:
             return FastModel.from_pretrained(
                 model_name = model_name,
-                max_seq_length = max_seq_length, # [TODO] No effect
+                max_seq_length = max_seq_length,
                 dtype = dtype,
                 load_in_4bit = load_in_4bit,
                 load_in_8bit = load_in_8bit,
@@ -442,7 +442,7 @@ class FastModel(FastBaseModel):
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
-        max_seq_length = None, # [TODO] No effect
+        max_seq_length = 2048,
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -500,6 +500,8 @@ def from_pretrained(
             raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
         elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
             raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+            raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)
         pass

         if USE_MODELSCOPE and not os.path.exists(model_name):
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index b4facf729..cb0d73c59 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -638,37 +638,45 @@
         "Qwen/QwQ-32B",
         "unsloth/QwQ-32B-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-it-bnb-4bit" : (
+    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-1b-it",
         "google/gemma-3-1b-it",
+        "unsloth/gemma-3-1b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-it-bnb-4bit" : (
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-4b-it",
         "google/gemma-3-4b-it",
+        "unsloth/gemma-3-4b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-it-bnb-4bit" : (
+    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-12b-it",
         "google/gemma-3-12b-it",
+        "unsloth/gemma-3-12b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-it-bnb-4bit" : (
+    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-27b-it",
         "google/gemma-3-27b-it",
+        "unsloth/gemma-3-27b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-1b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-1b-pt",
         "google/gemma-3-1b-pt",
+        "unsloth/gemma-3-1b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-4b-pt",
         "google/gemma-3-4b-pt",
+        "unsloth/gemma-3-4b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-12b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-12b-pt",
         "google/gemma-3-12b-pt",
+        "unsloth/gemma-3-12b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-27b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-27b-pt",
         "google/gemma-3-27b-pt",
+        "unsloth/gemma-3-27b-pt-bnb-4bit",
     ),
 }
diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 86a174ebf..4e158f58b 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -236,15 +236,24 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     mixed_precision = \
         "use_bf16 = getattr(args, 'bf16', False)\n"\
         "use_fp16 = getattr(args, 'fp16', False)\n"\
+        "force_float32 = False\n"\
+        "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n"\
+        "    if use_bf16 or use_fp16:\n"\
+        "        print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\
+        "    force_float32 = True\n"\
         "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\
         "dtype = getattr(model.config, 'torch_dtype', None)\n"\
         "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\
         "from unsloth_zoo.utils import _get_dtype\n"\
         "dtype = _get_dtype(dtype)\n"\
         "float16 = dtype == torch.float16\n"\
-        "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
-        "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
-        "if (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
+        "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
+        "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
+        "if force_float32:\n"\
+        "    args.fp16 = False\n"\
+        "    args.bf16 = False\n"\
+        "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"\
+        "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
         "    args.fp16 = float16\n"\
         "    args.bf16 = not float16\n"\
         "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
@@ -287,7 +296,10 @@
         "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\
         "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\
         "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\
-        "if os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
+        "if force_float32:\n"\
+        "    args.bf16_full_eval = False\n"\
+        "    args.fp16_full_eval = False\n"\
+        "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
         "    args.bf16_full_eval = True\n"\
         "    args.fp16_full_eval = False\n"\
         "elif not bf16_full_eval and not fp16_full_eval:\n"\
@@ -343,11 +355,9 @@
     if "data_collator" in call_args and "train_dataset" in call_args:
         data_collator_check = \
             "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
-            "    print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.')\n"\
             "    data_collator = DataCollatorForLanguageModeling("\
                 "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\
             "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
-            "    print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.')\n"\
             "    data_collator = DataCollatorForSeq2Seq("\
                 "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n"
         extra_args += data_collator_check
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index fa5547ec5..2ef9d2ee9 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -25,29 +25,49 @@
 except:
     from transformers import AutoModelForVision2Seq
 pass
-from .llama import *
 from ..kernels import (
     post_patch_loss_function,
 )
 from ._utils import __version__
+from ._utils import *
+from ..save import patch_saving_functions
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
+from peft import PeftModelForCausalLM
 from transformers import set_seed as transformers_set_seed
 from unsloth_zoo.peft_utils import (
     get_peft_regex,
     SKIP_QUANTIZATION_MODULES,
     requires_grad_for_gradient_checkpointing,
 )
+from transformers.models.llama.modeling_llama import logger
+from transformers import __version__ as transformers_version
 from triton import __version__ as triton_version
 from unsloth_zoo.utils import _get_dtype
 from unsloth_zoo.patching_utils import patch_model_and_tokenizer
 from unsloth_zoo.training_utils import prepare_model_for_training
 import types
 import functools
+import os
+import gc
+import math
+import functools
+from typing import Optional, Tuple, List, Union
+import re, inspect, sys
+import types
+try:
+    from huggingface_hub.utils import get_token
+except:
+    # Old HF Hub versions <= 0.0.25
+    from huggingface_hub.utils._token import get_token
+pass

 __all__ = [
     "FastBaseModel",
 ]

+global FORCE_FLOAT32
+FORCE_FLOAT32 = ["gemma3"]
+
 def unsloth_base_fast_generate(
     self,
@@ -86,6 +106,7 @@ def unsloth_base_fast_generate(
     except: pass

     # Mixed precision autocast
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32
     with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
         output = self._old_generate(*args, **kwargs)
     pass
@@ -100,7 +121,7 @@ class FastBaseModel:
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length = None,
+        max_seq_length = 2048,
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -114,6 +135,7 @@ def from_pretrained(
         use_gradient_checkpointing = "unsloth",
         **kwargs,
     ):
+        os.environ["UNSLOTH_USE_NEW_MODEL"] = "1"
         if trust_remote_code:
             print(
                 "Unsloth: WARNING `trust_remote_code` is True.\n"\
@@ -129,8 +151,12 @@ def from_pretrained(
         try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
         except: vllm_version = ""

+        model_type_arch = model_types[0]
+        if model_type_arch == "siglip" and len(model_types) != 1:
+            model_type_arch = model_types[1]
+
         statistics = \
-           f"==((====))==  Unsloth {__version__}: Fast {model_types[0].title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+           f"==((====))==  Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
           f"   {chr(92)}{chr(92)}   /|    {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
           f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
           f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
@@ -156,6 +182,17 @@ def from_pretrained(
         assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)

+        global FORCE_FLOAT32
+        os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
+        bnb_compute_dtype = dtype
+        for disable_name in FORCE_FLOAT32:
+            if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16:
+                print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.")
+                os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
+                bnb_compute_dtype = torch.float32
+                break
+        pass
+
         bnb_config = None
         if full_finetuning and (load_in_4bit or load_in_8bit):
             print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
@@ -170,13 +207,13 @@
                 load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
                 bnb_4bit_quant_type = "nf4",
-                bnb_4bit_compute_dtype = dtype,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                bnb_4bit_compute_dtype = bnb_compute_dtype,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif load_in_8bit:
             bnb_config = BitsAndBytesConfig(
                 load_in_8bit = True,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif not load_in_4bit and not load_in_8bit and not full_finetuning:
             print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
@@ -185,8 +222,8 @@
             load_in_4bit = True
             bnb_config = BitsAndBytesConfig(
                 load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
                 bnb_4bit_quant_type = "nf4",
-                bnb_4bit_compute_dtype = dtype,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                bnb_4bit_compute_dtype = bnb_compute_dtype,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         pass
@@ -212,7 +249,7 @@
             # quantization_config = bnb_config,
             token = token,
             trust_remote_code = trust_remote_code,
-            # attn_implementation = "sdpa", [TODO] Pixtral for eg fails
+            attn_implementation = "sdpa", #[TODO] Pixtral for eg fails
             **kwargs,
         )
         # Return old flag
@@ -408,12 +445,7 @@ def post_patch_model(
         from transformers.trainer import Trainer
         if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
-            raise RuntimeError(
-                'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
-                'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
-                'We do have a separate beta version, which you can contact us about!\n'\
-                'Thank you for your understanding and we appreciate it immensely!'
-            )
+            raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
         pass

         patch_saving_functions(model, vision = True)
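Taken together, a minimal end-to-end sketch of how the additions in this patch are meant to be used (the model name, LoRA rank, and dataset handling below are illustrative placeholders, not prescribed by the patch):

    from unsloth import FastModel
    from unsloth.chat_templates import get_chat_template, standardize_data_formats

    # Gemma 3 loads through the FastModel / FastBaseModel path; on float16-only GPUs
    # the loader switches compute to float32 internally via UNSLOTH_FORCE_FLOAT32.
    model, tokenizer = FastModel.from_pretrained(
        "unsloth/gemma-3-4b-it",
        max_seq_length = 2048,
        load_in_4bit = True,
    )
    model = FastModel.get_peft_model(model, r = 8)

    tokenizer = get_chat_template(tokenizer, chat_template = "gemma-3")
    # standardize_data_formats is the new name for standardize_sharegpt
    # (the old name is kept as an alias), e.g.:
    # dataset = standardize_data_formats(dataset)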