diff --git a/README.md b/README.md
index 1f85647f9..e6098cbeb 100644
--- a/README.md
+++ b/README.md
@@ -115,7 +115,7 @@ See [here](https://github.com/unslothai/unsloth/edit/main/README.md#advanced-pip
 7. **Install Unsloth:**
 
 ```python
-pip install "unsloth[windows] @ git+https://github.com/unslothai/unsloth.git"
+pip install unsloth
 ```
 
 #### Notes
diff --git a/pyproject.toml b/pyproject.toml
index 667901e76..7b1d2efda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,14 +33,11 @@ exclude = ["images*"]
 
 [project.optional-dependencies]
 triton = [
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'"
+    "triton-windows ; platform_system == 'Windows'",
 ]
 
 huggingface = [
-    "unsloth_zoo>=2025.3.9",
+    "unsloth_zoo>=2025.3.11",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",
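The new `triton` extra above replaces four per-Python wheel URLs with a single PEP 508 environment marker. A quick illustration of how such a marker resolves (a toy snippet using `packaging`, not part of the patch):

```python
# Toy illustration of the PEP 508 marker used in the new triton extra:
# the dependency only applies when the installing platform reports Windows.
from packaging.requirements import Requirement

req = Requirement("triton-windows ; platform_system == 'Windows'")
print(req.name)                                              # triton-windows
print(req.marker.evaluate({"platform_system": "Windows"}))   # True
print(req.marker.evaluate({"platform_system": "Linux"}))     # False
```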
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index 9bcdd5cf6..7ffddde9b 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
-    if Version(unsloth_zoo_version) < Version("2025.3.9"):
+    if Version(unsloth_zoo_version) < Version("2025.3.11"):
         print(
             "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\
             "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'"
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 2c2e36182..c10b2641a 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -1512,10 +1512,7 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
 
     # Remove duplicates
     splitted = joined_text.split("\x01\x00")
-    final_eos_tokens = []
-    for old, new in zip(added_tokens_decoder, splitted):
-        if old == new: final_eos_tokens.append(old)
-    pass
+    final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old == new]
 
     final_eos_tokens += extra_eos_tokens
     final_eos_tokens += repeatted_tokens
diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index 006dfff63..834a74c66 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -37,12 +37,12 @@ def _cross_entropy_forward(
     loss_ptr         ,
     logsumexp_ptr    ,
     labels_ptr       ,
-    VOCAB_SIZE       ,
+    VOCAB_SIZE       : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING   ,
-    SOFTCAP          ,
-    DO_LOGIT_SCALING ,
-    LOGIT_SCALE      ,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -111,13 +111,13 @@ def _chunked_cross_entropy_forward(
     loss_ptr         ,
     logsumexp_ptr    ,
     labels_ptr       ,
-    VOCAB_SIZE       ,
-    N_CHUNKS         ,
+    VOCAB_SIZE       : tl.constexpr,
+    N_CHUNKS         : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING   ,
-    SOFTCAP          ,
-    DO_LOGIT_SCALING ,
-    LOGIT_SCALE      ,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         256K vocab divided in 4 chunks
@@ -196,12 +196,12 @@ def _cross_entropy_backward(
     dloss_row_stride ,
     logsumexp_ptr    ,
     labels_ptr       ,
-    VOCAB_SIZE       ,
+    VOCAB_SIZE       : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING   ,
-    SOFTCAP          ,
-    DO_LOGIT_SCALING ,
-    LOGIT_SCALE      ,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py
index 26a77f03a..ed8182014 100644
--- a/unsloth/kernels/layernorm.py
+++ b/unsloth/kernels/layernorm.py
@@ -30,7 +30,8 @@ def layernorm_forward(
     b,
     r,
     mu,
-    n_cols, eps,
+    n_cols     : tl.constexpr,
+    eps        : tl.constexpr,
     BLOCK_SIZE : tl.constexpr
 ):
     row_idx = tl.program_id(0)
@@ -68,7 +69,8 @@ def layernorm_backward(
     b,
     r,
     mu,
-    n_cols, eps,
+    n_cols     : tl.constexpr,
+    eps        : tl.constexpr,
     BLOCK_SIZE : tl.constexpr
 ):
     # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py
index 1cde6388e..8f54e7490 100644
--- a/unsloth/kernels/rms_layernorm.py
+++ b/unsloth/kernels/rms_layernorm.py
@@ -22,9 +22,10 @@ def _rms_layernorm_forward(
     Y, Y_row_stride,
     X, X_row_stride,
     W, W_row_stride,
-    r, r_row_stride,
-    n_cols, eps,
-    BLOCK_SIZE : tl.constexpr
+    r, r_row_stride : tl.constexpr,
+    n_cols          : tl.constexpr,
+    eps             : tl.constexpr,
+    BLOCK_SIZE      : tl.constexpr,
 ):
     """
         Fast RMS Layernorm kernel
@@ -57,9 +58,10 @@ def _rms_layernorm_backward(
     dX, dX_row_stride,
     X,  X_row_stride,
     W,  W_row_stride,
-    r,  r_row_stride,
+    r,  r_row_stride : tl.constexpr,
     # dW, dW_row_stride,
-    n_cols, eps,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
     GEMMA      : tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
@@ -107,8 +109,9 @@ def _gemma_rms_layernorm_forward(
     Y, Y_row_stride,
     X, X_row_stride,
     W, W_row_stride,
-    r, r_row_stride,
-    n_cols, eps,
+    r, r_row_stride : tl.constexpr,
+    n_cols          : tl.constexpr,
+    eps             : tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
     # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
@@ -253,7 +256,6 @@ def unpatch_rms_layernorm():
     except:
         pass
     return
-    return
 pass
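The kernel changes above promote arguments such as `VOCAB_SIZE`, `DO_SOFTCAPPING`, `n_cols` and `eps` from runtime values to `tl.constexpr` parameters. As a rough sketch of what that buys (this toy kernel is not part of the patch), Triton specializes the compiled kernel per distinct constexpr value, so flags become compile-time branches and sizes become compile-time constants:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(
    X_ptr,
    Y_ptr,
    n_elements,
    SCALE      : tl.constexpr,  # folded into the compiled kernel
    DO_SCALE   : tl.constexpr,  # compile-time branch; the dead side is removed
    BLOCK_SIZE : tl.constexpr,  # tl.arange requires a constexpr extent
):
    pid     = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask    = offsets < n_elements
    x = tl.load(X_ptr + offsets, mask = mask, other = 0.0)
    if DO_SCALE: x = x * SCALE
    tl.store(Y_ptr + offsets, x, mask = mask)

x = torch.randn(1024, device = "cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](x, y, x.numel(), SCALE = 2.0, DO_SCALE = True, BLOCK_SIZE = 256)
```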
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index a3fc12f6d..06a76b19d 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2025.3.10"
+__version__ = "2025.3.11"
 
 __all__ = [
     "SUPPORTS_BFLOAT16",
@@ -72,6 +72,7 @@
 platform_system = platform_system()
 import numpy as np
 import contextlib
+import re
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version
 
@@ -181,6 +182,34 @@
     def filter(self, x): return not (self.text in x.getMessage())
 except:
     pass
 
+# Patch get_model_param_count to record correct 4bit / 8bit
+from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
+def get_model_param_count(model, trainable_only = False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+        def numel(p):
+            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
+    else:
+        def numel(p):
+            return p.numel()
+    s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+    if (not trainable_only) and \
+        hasattr(model, "config") and \
+        hasattr(model.config, "quantization_config"):
+
+        billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path)
+        if len(billions) != 0:
+            billions = int(billions[0])
+            s = 1_000_000_000 * billions
+    pass
+    return s
+pass
+import transformers.trainer_pt_utils
+transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
+import transformers.trainer
+transformers.trainer.get_model_param_count = get_model_param_count
 
 # =============================================
 # =============================================
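For 4-bit checkpoints, summing `p.numel()` over packed parameters under-reports the true size, so the patched `get_model_param_count` falls back to parsing the `<N>b` / `<N>B` hint out of `config.name_or_path`. A small illustration of that fallback (model names here are only examples):

```python
import re

for name in ["unsloth/llama-2-7b-chat", "unsloth/gemma-2-9b-it", "some-org/model-without-size"]:
    billions = re.findall(r"([0-9]{1,})(?:b|B)", name)
    total = 1_000_000_000 * int(billions[0]) if billions else None
    print(f"{name!r:40} -> {total}")

# 'unsloth/llama-2-7b-chat'                -> 7000000000
# 'unsloth/gemma-2-9b-it'                  -> 9000000000
# 'some-org/model-without-size'            -> None
```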
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 700073985..893a09dd1 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1663,6 +1663,10 @@ def from_pretrained(
         if platform.system().lower() == 'windows':
             print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!")
             fast_inference = False
+        major_version, minor_version = torch.cuda.get_device_capability()
+        if major_version < 7:
+            print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
+            fast_inference = False
         pass
 
         if token is None: token = get_token()
@@ -1786,6 +1790,8 @@ def from_pretrained(
                 attn_implementation = "eager",
                 **kwargs,
             )
+            model.fast_generate = model.generate
+            model.fast_generate_batches = None
         else:
             from unsloth_zoo.vllm_utils import (
                 load_vllm,
@@ -1804,6 +1810,7 @@ def from_pretrained(
                 enable_lora       = True,
                 max_lora_rank     = max_lora_rank,
                 disable_log_stats = disable_log_stats,
+                use_bitsandbytes  = load_in_4bit,
             )
             for allowed_arg in allowed_args:
                 if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
@@ -2651,6 +2658,19 @@ def patch_peft_model(
             torch.cuda.empty_cache()
         pass
 
+        # Patch for fast inference
+        vllm_engine = getattr(model.model, "vllm_engine", None)
+        if vllm_engine is not None:
+            model.vllm_engine = model.model.vllm_engine
+            model.fast_generate = model.model.fast_generate
+            model.fast_generate_batches = model.model.fast_generate_batches
+
+            # Also saving and loading LoRA
+            from unsloth_zoo.vllm_utils import save_lora, load_lora
+            model.save_lora = functools.partial(save_lora, model)
+            model.load_lora = functools.partial(load_lora, model)
+        pass
+
         # Add for_inference and for_training
         model.for_training  = functools.partial(FastLlamaModel.for_training, model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
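With the `patch_peft_model` change above, the fast-inference handles (`fast_generate`, `fast_generate_batches`, `save_lora`, `load_lora`) are forwarded from the wrapped base model onto the PEFT model. A hedged usage sketch of how that surface is typically exercised (model name and hyperparameters are illustrative, not taken from the patch):

```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = 2048,
    load_in_4bit   = True,
    fast_inference = True,   # vLLM backend; falls back on Windows or pre-Volta GPUs
)
model = FastLanguageModel.get_peft_model(model, r = 16, lora_alpha = 16)

# After the patch these forward to the vLLM engine held by the inner model
outputs = model.fast_generate(["Why is the sky blue?"])
model.save_lora("my_lora")
model.load_lora("my_lora")
```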
"unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4e158f58b..e412c3a5a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -354,13 +354,28 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Check data collator if it's correct! if "data_collator" in call_args and "train_dataset" in call_args: data_collator_check = \ - "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForLanguageModeling("\ - "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\ - "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForSeq2Seq("\ - "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n" + "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\ + "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\ + "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\ + " if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ + " data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\ + " elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ + " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\ + "else:\n"\ + " if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\ + " if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\ + " if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n" extra_args += data_collator_check + + # Also check if .pad exists -> if not, and is VLM, then change it! 
diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 4e158f58b..e412c3a5a 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -354,13 +354,28 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     # Check data collator if it's correct!
     if "data_collator" in call_args and "train_dataset" in call_args:
         data_collator_check = \
-        "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
-        "    data_collator = DataCollatorForLanguageModeling("\
-        "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\
-        "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
-        "    data_collator = DataCollatorForSeq2Seq("\
-        "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n"
+        "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\
+        "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\
+        "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
+        "    if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
+        "        data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\
+        "    elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
+        "        data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\
+        "else:\n"\
+        "    if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\
+        "    if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\
+        "    if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n"
         extra_args += data_collator_check
+
+        # Also check if .pad exists -> if not, and is VLM, then change it!
+        pad_check = \
+        "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
+        "    if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\
+        "        if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\
+        "            data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\
+        "        else:\n"\
+        "            data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n"
+        extra_args += pad_check
     pass
 
     # Check NEFTune
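The string injected above is compiled into the patched trainer's setup code; its effect is to pick a collator that matches the dataset. A standalone sketch of the same decision (toy dataset; the tokenizer name is only illustrative):

```python
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})   # no "labels" column

# Plain causal-LM data -> labels are the shifted inputs, so use the LM collator;
# datasets that already carry a "labels" column keep the seq2seq collator instead.
if "labels" not in train_dataset.column_names:
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
else:
    data_collator = DataCollatorForSeq2Seq(tokenizer)

batch = data_collator([train_dataset[0], train_dataset[1]])
print(batch["input_ids"].shape, batch["labels"].shape)   # padded to the longest sequence
```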
diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 7462d5594..4071ef835 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -207,9 +207,12 @@ def grpo_trainer__get_per_token_logps(function_name, function):
     if function_name != "_get_per_token_logps": return function
 
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
-        return None # Unsloth efficient GRPO
+        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return None # Unsloth efficient GRPO
+        # Otherwise, calculate normally:
         if not hasattr(self, '_autocast_dtype'):
             self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32
         with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
             # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
             logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
@@ -229,12 +232,14 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
 
-grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
-UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
-grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"]
+grpo_compute_loss      = RL_REPLACEMENTS["grpo_compute_loss"]
+grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
+UnslothEfficientGRPO   = RL_REPLACEMENTS["UnslothEfficientGRPO"]
+grpo_accumulated_loss  = RL_REPLACEMENTS["grpo_accumulated_loss"]
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss))
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO))
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss))
+RL_PRE_ITEMS["grpo_trainer"].append(grpo_compute_loss_slow)
 
 # Edit _get_per_token_logps to handle mixed precision
 def grpo_trainer_compute_loss(function_name, function):
@@ -266,8 +271,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         input_ids = input_ids[:, -logits_to_keep:]
-        if False:#per_token_logps is not None:
-            loss, completion_length, mean_kl = grpo_compute_loss(
+        if per_token_logps is not None:
+            loss, completion_length, mean_kl = grpo_compute_loss_slow(
                 ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
             )
         else:
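When the re-enabled branch above runs, per-token log-probabilities are passed to `grpo_compute_loss_slow`. The generic computation that such a path builds on looks roughly like this (standalone sketch with random tensors, not the trainer code):

```python
import torch

logits    = torch.randn(2, 5, 32000)          # (batch, seq, vocab) from the policy model
input_ids = torch.randint(0, 32000, (2, 5))   # tokens actually generated

# log-softmax over the vocabulary, then pick out each realised token's log-prob
logps = torch.log_softmax(logits.float(), dim = -1)
per_token_logps = logps.gather(dim = -1, index = input_ids.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)                  # torch.Size([2, 5])
```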
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index 2ef9d2ee9..31733c297 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -66,8 +66,17 @@
 ]
 
 global FORCE_FLOAT32
-FORCE_FLOAT32 = ["gemma3"]
+FORCE_FLOAT32 = [
+    "gemma3",
+]
+
+global FORCE_EAGER_ATTENTION
+FORCE_EAGER_ATTENTION = [
+    "pixtral", # Pixtral SDPA not implemented
+]
+global NUM_LOGITS_TO_KEEP
+NUM_LOGITS_TO_KEEP = dict()
 
 
 def unsloth_base_fast_generate(
     self,
@@ -78,21 +87,45 @@
     dtype = _get_dtype(self.config.torch_dtype)
 
     # Check if VLM
-    is_vlm = (
+    is_vlm = any(
         x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
         for x in self.config.architectures
     )
     is_vlm = is_vlm or hasattr(self.config, "vision_config")
+    arch = self.config.architectures[0]
 
     # Remove token_type_ids
     kwargs.pop("token_type_ids", None)
 
     # VLMs do not allow logits_to_keep
     if not is_vlm:
-        kwargs["logits_to_keep"] = 1
+        global NUM_LOGITS_TO_KEEP
+        if arch not in NUM_LOGITS_TO_KEEP:
+            m = self
+            # Find which is needed ie
+            # num_logits_to_keep or logits_to_keep
+            while hasattr(m, "model"):
+                if hasattr(m, "forward"):
+                    keys = inspect.signature(m.forward).parameters.keys()
+                    if "num_logits_to_keep" in keys:
+                        NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep"
+                        break
+                    elif "logits_to_keep" in keys:
+                        NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep"
+                        break
+                m = m.model
+            pass
+            if arch not in NUM_LOGITS_TO_KEEP:
+                NUM_LOGITS_TO_KEEP[arch] = None
+            pass
+        pass
+        key = NUM_LOGITS_TO_KEEP[arch]
+        if key is not None and key not in kwargs:
+            kwargs[key] = 1
     else:
-        kwargs.pop("logits_to_keep", None)
-        kwargs.pop("num_logits_to_keep", None)
+        pass
+        # kwargs.pop("logits_to_keep", None)
+        # kwargs.pop("num_logits_to_keep", None)
 
     # Check pad_token
     model_eos_token_id = getattr(self.config, "eos_token_id", None)
@@ -186,13 +219,27 @@ def from_pretrained(
         os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
         bnb_compute_dtype = dtype
         for disable_name in FORCE_FLOAT32:
-            if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16:
+            if (disable_name.lower() == model_type_arch.lower() or \
+                disable_name.lower() in model_name.lower()) and \
+                dtype == torch.float16:
+                print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.")
                 os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
                 bnb_compute_dtype = torch.float32
                 break
         pass
 
+        global FORCE_EAGER_ATTENTION
+        attn_implementation = "sdpa"
+        for disable_name in FORCE_EAGER_ATTENTION:
+            if (disable_name.lower() == model_type_arch.lower() or \
+                disable_name.lower() in model_name.lower()):
+
+                print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!")
+                attn_implementation = "eager"
+                break
+        pass
+
         bnb_config = None
         if full_finetuning and (load_in_4bit or load_in_8bit):
             print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
@@ -249,7 +296,7 @@ def from_pretrained(
             # quantization_config = bnb_config,
             token               = token,
             trust_remote_code   = trust_remote_code,
-            attn_implementation = "sdpa", #[TODO] Pixtral for eg fails
+            attn_implementation = attn_implementation,
             **kwargs,
         )
         # Return old flag
@@ -263,10 +310,20 @@ def from_pretrained(
             padding_side = "right",
            token        = token,
         )
-        # Add padding side as well
         if hasattr(tokenizer, "tokenizer"):
-            tokenizer.tokenizer.padding_side = "right"
-
+            __tokenizer = tokenizer.tokenizer
+            # Add padding side as well
+            __tokenizer.padding_side = "right"
+            # Check bos, eos, pad, unk tokens
+            tokens = ["bos_token", "eos_token", "pad_token", "unk_token"]
+            for token in tokens:
+                if hasattr(__tokenizer, token) and not hasattr(tokenizer, token):
+                    _args = {"__tokenizer" : __tokenizer, "tokenizer" : tokenizer}
+                    exec(f"tokenizer.{token} = __tokenizer.{token}", _args)
+                    exec(f"tokenizer.{token}_id = __tokenizer.{token}_id", _args)
+                pass
+            pass
+        pass
         model, tokenizer = patch_tokenizer(model, tokenizer)
         model = post_patch_loss_function(model)
         # Fix other stuff like BnB compute data types
diff --git a/unsloth/save.py b/unsloth/save.py
index d03f47e87..4b2c01298 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -2219,6 +2219,10 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
 
 from .models.loader_utils import get_model_name
 from unsloth_zoo.saving_utils import merge_and_overwrite_lora
+from unsloth_zoo.llama_cpp import (
+    install_llama_cpp,
+    convert_to_gguf,
+)
 
 @torch.inference_mode
 def unsloth_generic_save(
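`unsloth_base_fast_generate` above caches, per architecture, whether the underlying `forward` spells the logit-trimming kwarg `num_logits_to_keep` or `logits_to_keep`. A self-contained sketch of that introspection idea (toy wrapper classes, not the real model stack):

```python
import inspect

class InnerModel:
    def forward(self, input_ids, attention_mask = None, num_logits_to_keep = 0): ...

class Wrapper:
    def __init__(self): self.model = InnerModel()
    def forward(self, *args, **kwargs): ...

def find_logits_kwarg(m):
    # Walk down .model wrappers until a forward() advertises one of the two names
    while hasattr(m, "model"):
        m = m.model
        keys = inspect.signature(m.forward).parameters.keys()
        if "num_logits_to_keep" in keys: return "num_logits_to_keep"
        if "logits_to_keep"     in keys: return "logits_to_keep"
    return None

kwarg = find_logits_kwarg(Wrapper())
print(kwarg)   # num_logits_to_keep -> generate(**{kwarg: 1}) keeps only the final logit row
```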