From 12f97c811790e39002206f9628f4a46a48f03a91 Mon Sep 17 00:00:00 2001 From: Itsuro Tajima Date: Tue, 26 Nov 2024 20:20:34 +0900 Subject: [PATCH 001/473] use exact model name --- unsloth/models/loader.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 232fe6acf..19747cb4e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -78,12 +78,14 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, revision = None, + use_exact_model_name = False, *args, **kwargs, ): if token is None: token = get_token() old_model_name = model_name - model_name = get_model_name(model_name, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(model_name, load_in_4bit) # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled @@ -162,7 +164,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + else: + model_name = peft_config.base_model_name_or_path model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -249,6 +254,8 @@ def from_pretrained( tokenizer_name = None pass + original_kwargs = kwargs.copy() + model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -262,7 +269,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, - *args, **kwargs, + *args, **original_kwargs, ) if resize_model_vocab is not None: @@ -347,6 +354,7 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, # [TODO] No effect revision = None, + use_exact_model_name = False, *args, **kwargs, ): if token is None: token = get_token() @@ -357,7 +365,8 @@ def from_pretrained( patch_unsloth_smart_gradient_checkpointing() old_model_name = model_name - model_name = get_model_name(model_name, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(model_name, load_in_4bit) with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) @@ -462,7 +471,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + else: + model_name = peft_config.base_model_name_or_path model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -483,6 +495,8 @@ def from_pretrained( tokenizer_name = None pass + original_kwargs = kwargs.copy() + model, tokenizer = FastBaseVisionModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -494,7 +508,7 @@ def from_pretrained( revision = revision if not is_peft else None, model_types = model_types, tokenizer_name = tokenizer_name, - *args, **kwargs, + *args, **original_kwargs, ) if resize_model_vocab is not None: From c4cb50bd1396c052280da8582798eb87f0de8dbc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 17:07:25 -0800 Subject: [PATCH 002/473] Update save.py --- unsloth/save.py | 3 ++- 
1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/save.py b/unsloth/save.py index 8db3b6dc3..d3ba1928c 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2131,7 +2131,8 @@ def unsloth_generic_save( if token is None and push_to_hub: token = get_token() merge_and_overwrite_lora( get_model_name, - model, + model = model, + tokenizer = tokenizer, save_directory = save_directory, push_to_hub = push_to_hub, private = private, From 75e4756a4ea8b2813f9afd80ed8252f1778dc58f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 17:11:22 -0800 Subject: [PATCH 003/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4f1b40884..e508c96b0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1104,7 +1104,7 @@ def patch_gradient_accumulation_fix(Trainer): "else:\n"\ "\2if num_items_in_batch is None:\n"\ - "\3loss /= self.args.gradient_accumulation_steps\n"\ + "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", function, From e86b18f0470a1517bf02929ee450d15c5f59b5af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:12:52 -0800 Subject: [PATCH 004/473] Update _utils.py --- unsloth/models/_utils.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e508c96b0..1a8b20365 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1008,15 +1008,38 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples += [next(epoch_iterator)] except StopIteration: break + if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: - num_items_in_batch = sum( - [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] - ) - except TypeError: + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + except (TypeError, AttributeError): pass + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + return batch_samples, num_items_in_batch -pass + +# def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): +# batch_samples = [] +# num_items_in_batch = None +# for _ in range(num_batches): +# try: +# batch_samples += [next(epoch_iterator)] +# except StopIteration: +# break +# if len(batch_samples) > 0 and "labels" in batch_samples[0]: +# try: +# num_items_in_batch = sum( +# [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] +# ) +# except TypeError: +# pass +# return batch_samples, num_items_in_batch +# pass def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): From f565ccfea16c7854c19d310af2e0b7e6e8d3c651 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:19:45 -0800 Subject: [PATCH 005/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1a8b20365..c9ca3eb1e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1126,6 +1126,7 @@ def patch_gradient_accumulation_fix(Trainer): r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps", "else:\n"\ + "\1print(self.args.gradient_accumulation_steps)\n" "\2if num_items_in_batch is 
None:\n"\ "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", From c5d0aa983e0dc74e76af469ecf8807c31e70fc39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:21:30 -0800 Subject: [PATCH 006/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c9ca3eb1e..5fa6b5de5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,6 +1009,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except StopIteration: break + print("NUM_ITMES = ", num_items_in_batch) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) From af7d6cc8710085c3a930ff99dcfce60c5043762e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:45:16 -0800 Subject: [PATCH 007/473] print --- unsloth/models/_utils.py | 1 - unsloth/models/llama.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5fa6b5de5..4bedce38e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1127,7 +1127,6 @@ def patch_gradient_accumulation_fix(Trainer): r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps", "else:\n"\ - "\1print(self.args.gradient_accumulation_steps)\n" "\2if num_items_in_batch is None:\n"\ "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966..ddee9e901 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1009,6 +1009,7 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) + print(0, n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, @@ -1055,6 +1056,7 @@ def _CausalLM_fast_forward( # Fixes https://github.com/unslothai/unsloth/issues/10 self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass + print(1, kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)) shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( logits = shift_logits, From 281cb7348577f8431a72f3bf81c32be3f1db3cc0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:12:02 -0800 Subject: [PATCH 008/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4bedce38e..512812cb7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,7 +1009,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except StopIteration: break - print("NUM_ITMES = ", num_items_in_batch) + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) From b60acdad485179a64e0b176e39fe2880c60f6f19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:12:13 -0800 Subject: [PATCH 009/473] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 512812cb7..14da9fc42 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1008,8 +1008,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples += [next(epoch_iterator)] except StopIteration: break - - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) @@ -1022,6 +1020,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) + return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): From 855d0f8bed06b5d23588acccbb31f296518bcd09 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:16:58 -0800 Subject: [PATCH 010/473] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ddee9e901..c94514966 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1009,7 +1009,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) - print(0, n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, @@ -1056,7 +1055,6 @@ def _CausalLM_fast_forward( # Fixes https://github.com/unslothai/unsloth/issues/10 self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass - print(1, kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)) shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( logits = shift_logits, From fe4e9b8f65b40edadac22fe4a3052f215014ce88 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:30:49 -0800 Subject: [PATCH 011/473] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 14da9fc42..18918a3c7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1011,7 +1011,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) - except (TypeError, AttributeError): + except Exception as exception: + logger.warning_once(exception) pass if self.args.average_tokens_across_devices: From 48161a23427386d1a1ad7661658805a7a55e846f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 21:39:38 -0800 Subject: [PATCH 012/473] Update vision.py --- unsloth/models/vision.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 709cd1cb5..2dc4b88df 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -186,6 +186,10 @@ def from_pretrained( patch_saving_functions(model, vision = True) patch_saving_functions(tokenizer, vision = True) + # Fix gradient accumulation + from transformers.trainer import Trainer + patch_gradient_accumulation_fix(Trainer) + # Save tokenizer for inference purposes 
tokenizer.padding_side = "left" # Force inference tokenizer.tokenizer.padding_side = "left" # Force inference From 52b24512de064080096ec7949fbe48efbeef8aca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 22:25:19 -0800 Subject: [PATCH 013/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 18918a3c7..986b938f1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1021,7 +1021,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples), self.model) return batch_samples, num_items_in_batch From 8d39e731207c2d550f900a626eeb145d8a144553 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:35:15 -0800 Subject: [PATCH 014/473] Update _utils.py --- unsloth/models/_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 986b938f1..c1bc7aa97 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1003,25 +1003,30 @@ def test_mask_creation(): def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples = [] num_items_in_batch = None + + # Check if model allows **kwargs + model = self.model + f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward + has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD + for _ in range(num_batches): try: batch_samples += [next(epoch_iterator)] except StopIteration: break - if len(batch_samples) > 0 and "labels" in batch_samples[0]: + if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() except Exception as exception: logger.warning_once(exception) pass - - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples), self.model) + pass return batch_samples, num_items_in_batch From a7e580386d8bdc3a7270235261d76a8e4195dad0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:37:25 -0800 Subject: [PATCH 015/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c1bc7aa97..9725f624a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1025,9 +1025,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): num_items_in_batch = num_items_in_batch.item() except Exception as exception: logger.warning_once(exception) - pass pass + print(batch_samples, num_items_in_batch) return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): From 5038ba73435265ce66c569fff04aced57b1b7727 Mon Sep 17 00:00:00 2001 From: Daniel Han 
Date: Thu, 26 Dec 2024 23:41:44 -0800 Subject: [PATCH 016/473] Update _utils.py --- unsloth/models/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9725f624a..2a7532a99 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1026,8 +1026,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - - print(batch_samples, num_items_in_batch) return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): @@ -1051,6 +1049,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): if "num_items_in_batch" in kwargs: + if kwargs["num_items_in_batch"] is None: + # Remove it since the model does not support it! + kwargs.pop("num_items_in_batch", None) if "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass From 0882287a730fbd9af5d327da925e06d4371b29b4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:45:19 -0800 Subject: [PATCH 017/473] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2a7532a99..29de5858d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1051,8 +1051,8 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): if "num_items_in_batch" in kwargs: if kwargs["num_items_in_batch"] is None: # Remove it since the model does not support it! - kwargs.pop("num_items_in_batch", None) - if "num_items_in_batch" not in inputs: + kwargs.pop("num_items_in_batch") + elif "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass pass From ab71dce435e9f3f6c66fb4d0a018e01693ca24a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:33:41 -0800 Subject: [PATCH 018/473] Update _utils.py --- unsloth/models/_utils.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 29de5858d..32b1daaa0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,42 +1009,34 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD + # Iterate to find all batches for _ in range(num_batches): try: batch_samples += [next(epoch_iterator)] except StopIteration: break + pass + + # Get num_items_in_batch if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: - num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + num_items_in_batch = sum( + [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] + ) + # num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() + except Exception as exception: logger.warning_once(exception) pass - return batch_samples, num_items_in_batch -# def _unsloth_get_batch_samples(self, epoch_iterator, 
num_batches): -# batch_samples = [] -# num_items_in_batch = None -# for _ in range(num_batches): -# try: -# batch_samples += [next(epoch_iterator)] -# except StopIteration: -# break -# if len(batch_samples) > 0 and "labels" in batch_samples[0]: -# try: -# num_items_in_batch = sum( -# [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] -# ) -# except TypeError: -# pass -# return batch_samples, num_items_in_batch -# pass + return batch_samples, num_items_in_batch +pass def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): From dd054c3dd409984fbb02843747edb7f6af003cae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:54:12 -0800 Subject: [PATCH 019/473] Update _utils.py --- unsloth/models/_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 32b1daaa0..762ebd1fd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1047,6 +1047,13 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): elif "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass + else: + name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ + logger.warning_once( + f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ + "Using gradient accumulation will be very slightly less accurate.\n"\ + "Read more on gradient accumulation issues on our blog post: https://unsloth.ai/blog/gradient" + ) pass return self._old_compute_loss(model, inputs, *args, **kwargs) pass From 6c80d0fb545c79fa86766a757dfc55f6b025565b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:59:16 -0800 Subject: [PATCH 020/473] Update _utils.py --- unsloth/models/_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 762ebd1fd..af1d35bd9 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1040,19 +1040,24 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): + num_items_in_batch = None + if "num_items_in_batch" in kwargs: - if kwargs["num_items_in_batch"] is None: + num_items_in_batch = kwargs["num_items_in_batch"] + if num_items_in_batch is None: # Remove it since the model does not support it! 
kwargs.pop("num_items_in_batch") elif "num_items_in_batch" not in inputs: - inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] + inputs["num_items_in_batch"] = num_items_in_batch pass - else: + pass + + if num_items_in_batch is None: name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ "Using gradient accumulation will be very slightly less accurate.\n"\ - "Read more on gradient accumulation issues on our blog post: https://unsloth.ai/blog/gradient" + "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass return self._old_compute_loss(model, inputs, *args, **kwargs) From ea8e8a2126f2063dc33698f67476e28811d58e29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:02:40 -0800 Subject: [PATCH 021/473] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index d1c8b1e07..824986dc1 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From 33ed089d846b43928e1b79f11a89f4697912e777 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:11:53 -0800 Subject: [PATCH 022/473] accurate_accumulation --- unsloth/models/_utils.py | 2 ++ unsloth/models/loader.py | 1 + 2 files changed, 3 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index af1d35bd9..1f2f9018d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1183,6 +1183,7 @@ def unsloth_compile_transformers( manual_replacements = True, fast_lora_forwards = True, fast_residual_stream = True, + accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, shape_padding = True, @@ -1229,6 +1230,7 @@ def unsloth_compile_transformers( manual_replacements = manual_replacements, fast_lora_forwards = fast_lora_forwards, fast_residual_stream = fast_residual_stream, + accurate_accumulation = accurate_accumulation, epilogue_fusion = epilogue_fusion, max_autotune = max_autotune, shape_padding = shape_padding, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 824986dc1..2fe037eb3 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -472,6 +472,7 @@ def from_pretrained( manual_replacements = True, fast_lora_forwards = False, fast_residual_stream = False, + accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, shape_padding = True, From c3b41b8f65e3db5275de03b2633c935cedb8b3c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:12:03 -0800 Subject: [PATCH 023/473] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2fe037eb3..16f8c76d9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = 
True, From 142f026391c88693bcc3eb398528d5884c79b227 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:21:41 -0800 Subject: [PATCH 024/473] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 16f8c76d9..6aa6830b8 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -471,7 +471,7 @@ def from_pretrained( gradient_checkpointing = True, manual_replacements = True, fast_lora_forwards = True, - fast_residual_stream = False, + fast_residual_stream = True, accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, From eecab406017ae9c6f2f47c4064297146f00b5586 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:24:18 -0800 Subject: [PATCH 025/473] Update _utils.py --- unsloth/models/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1f2f9018d..86346d7e2 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1023,8 +1023,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): num_items_in_batch = sum( [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] ) - # num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) - + if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() From 8cec2facdb5b42957979791019cd7691108132f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:29:58 -0800 Subject: [PATCH 026/473] Update loader.py --- unsloth/models/loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6aa6830b8..113c4fbc7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,8 +470,8 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = True, + fast_lora_forwards = False, + fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, From c68007cc1c97c67355f282ccf0d494863752e106 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 18:02:00 -0800 Subject: [PATCH 027/473] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 113c4fbc7..2ec774515 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From 549531125f4c5ba3c122fbaa89f704453c6ddda4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 21:28:49 -0800 Subject: [PATCH 028/473] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 
2ec774515..16f8c76d9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From ea2c6475b216a548fc1c93aecf68fcc76990dd2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:53:46 -0800 Subject: [PATCH 029/473] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 16f8c76d9..113c4fbc7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = True, + fast_lora_forwards = False, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From f1da2a63f3000197d19415e0f516e1c02b060139 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:57:21 -0800 Subject: [PATCH 030/473] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9abe7a5d8..ce3301547 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2024.12.6", + "unsloth_zoo>=2024.12.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2024.12.6", + "unsloth_zoo>=2024.12.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From 3e1dbaba6321ed13cf3a7b21ffe56b5a8a349abd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:29:19 -0800 Subject: [PATCH 031/473] Update __init__.py --- unsloth/__init__.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 980425e1f..f8239ccf9 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -89,6 +89,36 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass +# Fix Xformers +import importlib.util +from pathlib import Path +from importlib.metadata import version as importlib_version +from packaging.version import Version +try: + xformers_version = importlib_version("xformers") + if Version(xformers_version) < Version("0.0.29"): + xformers_location = importlib.util.find_spec("xformers").origin + xformers_location = os.path.split(xformers_location)[0] + cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py" + + if cutlass.exists(): + with open(cutlass, "r+") as f: + text = f.read() + # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 + if "num_splits_key=-1," in text: + print("Unsloth: Patching Xformers to fix some 
performance issues.") + text = text.replace("num_splits_key=-1,", "num_splits_key=None,") + pass + f.seek(0) + f.write(text) + f.truncate() + pass + pass + pass +except: + pass +pass + # Torch 2.4 has including_emulation major_version, minor_version = torch.cuda.get_device_capability() SUPPORTS_BFLOAT16 = (major_version >= 8) From a0d39ffbca35d8e2eed5e0c1517d8f420a962cd4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:34:30 -0800 Subject: [PATCH 032/473] Update pyproject.toml --- pyproject.toml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ce3301547..ec17247d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,20 +148,20 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu121onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ 
https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", From c3d4e188a5f0d058c8d9f7b8bf9c5462f74fbb8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:35:05 -0800 Subject: [PATCH 033/473] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f8239ccf9..10bcd2508 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -89,7 +89,7 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass -# Fix Xformers +# Fix Xformers performance issues since 0.0.25 import importlib.util from pathlib import Path from importlib.metadata import version as importlib_version From 7d7a1b0ef43b575aa6589e8283667b9fdf7d0590 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:36:11 -0800 Subject: [PATCH 034/473] Update __init__.py --- unsloth/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 10bcd2508..afd255dc3 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -106,12 +106,12 @@ text = f.read() # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 if "num_splits_key=-1," in text: - print("Unsloth: Patching Xformers to fix some performance issues.") text = text.replace("num_splits_key=-1,", "num_splits_key=None,") + f.seek(0) + f.write(text) + f.truncate() + print("Unsloth: Patching Xformers to fix some performance issues.") pass - f.seek(0) - f.write(text) - f.truncate() pass pass pass From bfce3d402c152b084acdc3fda064d585aafef25d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 13:52:42 -0800 Subject: [PATCH 035/473] Fix Triton heuristics https://github.com/triton-lang/triton/issues/5224 --- unsloth/kernels/cross_entropy_loss.py | 37 +++++++++++++++------------ unsloth/kernels/rms_layernorm.py | 8 ++++-- unsloth/kernels/rope_embedding.py | 8 ++++-- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d347cd187..fcba2eb6d 100644 --- 
a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,11 +25,6 @@ ) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _cross_entropy_forward( logits_ptr , logits_row_stride , @@ -95,13 +90,15 @@ def _cross_entropy_forward( tl.store(logsumexp_ptr, logsumexp) tl.store(loss_ptr, loss) pass +_cross_entropy_forward = triton.jit(_cross_entropy_forward) +_cross_entropy_forward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_cross_entropy_forward) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _chunked_cross_entropy_forward( logits_ptr , logits_row_stride , @@ -177,13 +174,15 @@ def _chunked_cross_entropy_forward( pass tl.store(logsumexp_ptr, logsumexp) pass +_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward) +_chunked_cross_entropy_forward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_chunked_cross_entropy_forward) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _cross_entropy_backward( logits_ptr , logits_row_stride , @@ -264,10 +263,16 @@ def _cross_entropy_backward( # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0. tl.store(logits_ptr + col_offsets, dloss * y, mask = mask) pass +_cross_entropy_backward = triton.jit(_cross_entropy_backward) +_cross_entropy_backward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_cross_entropy_backward) MAX_FUSED_SIZE = 65536 # 2**16 - class Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0): diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index b74d636c6..6310f7f39 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -53,8 +53,6 @@ def _rms_layernorm_forward( pass -@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),}) -@triton.jit def _rms_layernorm_backward( dY, dY_row_stride, dX, dX_row_stride, @@ -97,6 +95,12 @@ def _rms_layernorm_backward( output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) tl.store(dX + col_offsets, output, mask = mask) pass +_rms_layernorm_backward = triton.jit(_rms_layernorm_backward) +_rms_layernorm_backward = triton.heuristics( + { + "GEMMA": lambda args: bool(args["GEMMA"]), + } +)(_rms_layernorm_backward) @triton.jit diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 7fe15d0e3..88b9ccadb 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -18,8 +18,6 @@ from .utils import calculate_settings ROPE_GROUP_SIZE : int = 4 -@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),}) -@triton.jit def _rope_embedding( Q, Q_row_stride, cos, cos_row_stride, @@ -69,6 +67,12 @@ def _rope_embedding( tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask) pass pass 
+_rope_embedding = triton.jit(_rope_embedding) +_rope_embedding = triton.heuristics( + { + "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]), + } +)(_rope_embedding) class Fast_RoPE_Embedding(torch.autograd.Function): From 743106eaf617677bb39aaa4b9fce43a485c5376a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 14:13:49 -0800 Subject: [PATCH 036/473] Update __init__.py --- unsloth/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index afd255dc3..90d2a6351 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -55,7 +55,12 @@ pass # Reduce VRAM usage by reducing fragmentation -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]" +# And optimize pinning of memory +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ + "expandable_segments:True,"\ + "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "pinned_use_cuda_host_register:True,"\ + "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From 4e0986fbe45c8267fc27ee32675f06bc645570ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:38:18 -0800 Subject: [PATCH 037/473] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 90d2a6351..25d4e2b0a 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,7 +58,7 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "roundup_power2_divisions:[64:128,256:64,>:32],"\ "pinned_use_cuda_host_register:True,"\ "pinned_num_register_threads:8" From abebd113befc427dae39856c108176fa851bef33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 12:31:30 -0800 Subject: [PATCH 038/473] Update __init__.py --- unsloth/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 25d4e2b0a..0b46794e9 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,9 +58,7 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[64:128,256:64,>:32],"\ - "pinned_use_cuda_host_register:True,"\ - "pinned_num_register_threads:8" + "roundup_power2_divisions:[64:128,256:64,>:32]" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From f0216092b9bb60a799e021b5dadd2290ef43b756 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 12:35:56 -0800 Subject: [PATCH 039/473] Update __init__.py --- unsloth/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 0b46794e9..90d2a6351 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,7 +58,9 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[64:128,256:64,>:32]" + "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "pinned_use_cuda_host_register:True,"\ + "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From 512773e69166fa405bb0450cc486ddd596f100ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:42:28 -0800 Subject: [PATCH 040/473] Xformers --- 
pyproject.toml | 24 ++++++++++++------------ unsloth/models/loader.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ec17247d1..bf4c99528 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,20 +148,20 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu121onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ 
https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 113c4fbc7..2fe037eb3 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From b4549cd93e7a3dfad8001c80d07e914e27d62537 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 02:06:41 -0800 Subject: [PATCH 041/473] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2fe037eb3..2ec774515 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From 67604993b0493ffc47f2dfabac90c95faeaa3e6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:27:08 -0800 Subject: [PATCH 042/473] Update loader.py --- unsloth/models/loader.py | 60 +++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2ec774515..20c0177d7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,35 +454,37 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - 
manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - fullgraph = fullgraph, - import_from_cache = False, - disable = False, - return_logits = return_logits, - ) + if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "0": + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) + pass pass # Check if this is local model since the tokenizer gets overwritten From c25f20ce70062a16a87f2beba2fb449b9f9d8a46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:34:18 -0800 Subject: [PATCH 043/473] Rewind --- unsloth/models/_utils.py | 4 +-- unsloth/models/loader.py | 60 +++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3cb6ffb8f..386d71354 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1203,8 +1203,6 @@ def unsloth_compile_transformers( return pass - if disable: return - model_types = get_transformers_model_type( model_name = model_name, token = token, @@ -1212,6 +1210,8 @@ def unsloth_compile_transformers( trust_remote_code = trust_remote_code, ) + if disable: return + for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 20c0177d7..2ec774515 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,37 +454,35 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "0": - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - fullgraph = fullgraph, - import_from_cache = False, - disable = False, - return_logits = return_logits, - ) - pass + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + 
sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) pass # Check if this is local model since the tokenizer gets overwritten From c90b3bfecfd04e51534a1b855c76cc3f3fc88426 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:32:58 -0800 Subject: [PATCH 044/473] Update _utils.py --- unsloth/models/_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 386d71354..9bd3598b1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.12.12" +__version__ = "2025.1.1" __all__ = [ "prepare_model_for_kbit_training", @@ -110,6 +110,9 @@ get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) +from unsloth_zoo.peft_utils import ( + requires_grad_for_gradient_checkpointing, +) # ============================================= # Disable some warnings which can get annoying @@ -557,6 +560,10 @@ def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + # Enable grads on non language models as well + requires_grad_for_gradient_checkpointing() + pass + return model pass From 937952292efd0cfc2a0f1e662192f96ecdec3d2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:34:40 -0800 Subject: [PATCH 045/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9bd3598b1..33fb36e8b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -561,7 +561,7 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) # Enable grads on non language models as well - requires_grad_for_gradient_checkpointing() + requires_grad_for_gradient_checkpointing(model) pass return model From 9a66c6f1578a2eeee7ecad9c169b65f4a7394947 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:25:25 -0800 Subject: [PATCH 046/473] requires grad --- unsloth/__init__.py | 21 ++++++++++----------- unsloth/models/_utils.py | 6 ------ unsloth/models/vision.py | 3 +++ 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 90d2a6351..bbeded9fc 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -17,16 +17,6 @@ import os, re, subprocess, inspect import numpy as np -# # Define a list of modules to check -# MODULES_TO_CHECK = ["bitsandbytes"] - -# # Check if any of the modules in the list have been imported -# for module in MODULES_TO_CHECK: -# if module in sys.modules: -# raise ImportError(f"Unsloth: Please import Unsloth before {module}.") -# pass -# pass - # Unsloth currently does not work on 
multi GPU setups - sadly we are a 2 brother team so # enabling it will require much more work, so we have to prioritize. Please understand! # We do have a beta version, which you can contact us about! @@ -201,9 +191,18 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: + unsloth_zoo_version = importlib_version("unsloth_zoo") + if Version(unsloth_zoo_version) < Version("2025.1.1"): + try: + os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") + except: + try: + os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") + except: + raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") import unsloth_zoo except: - raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`") + raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") pass from .models import * diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 33fb36e8b..098f5c3e4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -110,9 +110,6 @@ get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) -from unsloth_zoo.peft_utils import ( - requires_grad_for_gradient_checkpointing, -) # ============================================= # Disable some warnings which can get annoying @@ -559,9 +556,6 @@ def prepare_model_for_kbit_training( def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # Enable grads on non language models as well - requires_grad_for_gradient_checkpointing(model) pass return model diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2dc4b88df..51450aa0d 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -30,6 +30,7 @@ from unsloth_zoo.peft_utils import ( get_peft_regex, SKIP_QUANTIZATION_MODULES, + requires_grad_for_gradient_checkpointing, ) from triton import __version__ as triton_version @@ -275,6 +276,8 @@ def get_peft_model( use_gradient_checkpointing = use_gradient_checkpointing, ) model = get_peft_model(model, lora_config) + # Enable gradients on modules which are trainable + requires_grad_for_gradient_checkpointing(model) model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing) From bb9ab04dd8402ec10aaba86ab7383da58e25239a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 22:44:17 -0800 Subject: [PATCH 047/473] Update loader.py --- unsloth/models/loader.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2ec774515..3e54ef2cd 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -32,7 +32,7 @@ from huggingface_hub import HfFileSystem # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 
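(Illustrative aside, not part of the patch series.) The unsloth_zoo minimum-version guard added in the patch above reduces to the comparison sketched below; the package name and threshold come from the patch, while the helper name is invented for this example.

from importlib.metadata import version as importlib_version
from packaging.version import Version

def zoo_needs_upgrade(minimum: str = "2025.1.1") -> bool:
    # True when the installed unsloth_zoo is older than the pinned minimum;
    # in that case the patch shells out to
    # `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`.
    return Version(importlib_version("unsloth_zoo")) < Version(minimum)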
-from unsloth_zoo.utils import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) SUPPORTS_FOURBIT = transformers_version >= Version("4.37") SUPPORTS_GEMMA = transformers_version >= Version("4.38") @@ -47,23 +47,6 @@ pass import torch -def _get_dtype(dtype): - __DTYPE_MAP = { - "float32": torch.float32, - torch.float32: torch.float32, - "float16": torch.float16, - torch.float16: torch.float16, - "bfloat16": torch.bfloat16, - torch.bfloat16: torch.bfloat16, - } - if dtype is None or dtype == None: return None - elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype] - else: - print(f"Unsloth: {dtype} is not recognized, so we'll default to None") - return None - pass -pass - class FastLanguageModel(FastLlamaModel): @staticmethod From 3e096ac6ba40a2aad9ed7f5036d168798976ea90 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 00:42:39 -0800 Subject: [PATCH 048/473] Update _utils.py --- unsloth/models/_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 098f5c3e4..3752d46d0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -58,7 +58,6 @@ "fused_linear_cross_entropy", "patch_unsloth_smart_gradient_checkpointing", "unpatch_unsloth_smart_gradient_checkpointing", - "create_gradient_checkpointing_buffer", "patch_compiled_autograd", "process_vision_info", @@ -97,7 +96,6 @@ patch_unsloth_smart_gradient_checkpointing, unpatch_unsloth_smart_gradient_checkpointing, - create_gradient_checkpointing_buffer, ) from unsloth_zoo.loss_utils import ( HAS_CUT_CROSS_ENTROPY, From 99898da0d34226ac2f040bc0ac4e17094e19de6d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 22:03:11 -0800 Subject: [PATCH 049/473] Update loader.py --- unsloth/models/loader.py | 114 ++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 62 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 19747cb4e..a88114669 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -13,6 +13,7 @@ # limitations under the License. from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING +from .granite import FastGraniteModel from .llama import FastLlamaModel, logger from .mistral import FastMistralModel from .qwen2 import FastQwen2Model @@ -31,13 +32,14 @@ from huggingface_hub import HfFileSystem # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 
-from packaging.version import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) SUPPORTS_FOURBIT = transformers_version >= Version("4.37") SUPPORTS_GEMMA = transformers_version >= Version("4.38") SUPPORTS_GEMMA2 = transformers_version >= Version("4.42") SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2") SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0") +SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") if SUPPORTS_GEMMA: from .gemma import FastGemmaModel if SUPPORTS_GEMMA2: @@ -45,28 +47,11 @@ pass import torch -def _get_dtype(dtype): - __DTYPE_MAP = { - "float32": torch.float32, - torch.float32: torch.float32, - "float16": torch.float16, - torch.float16: torch.float16, - "bfloat16": torch.bfloat16, - torch.bfloat16: torch.bfloat16, - } - if dtype is None or dtype == None: return None - elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype] - else: - print(f"Unsloth: {dtype} is not recognized, so we'll default to None") - return None - pass -pass - class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", + model_name = "unsloth/Llama-3.2-1B-Instruct", max_seq_length = None, dtype = None, load_in_4bit = True, @@ -131,7 +116,8 @@ def from_pretrained( exist_config = os.path.exists(os.path.join(model_name, "config.json")) both_exist = exist_adapter_config and exist_config else: - files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json")) + # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows. + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = (os.path.split(x)[-1] for x in files) if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2: both_exist = True @@ -164,10 +150,9 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT + model_name = peft_config.base_model_name_or_path if not use_exact_model_name: - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) - else: - model_name = peft_config.base_model_name_or_path + model_name = get_model_name(model_name, load_in_4bit) model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -180,7 +165,7 @@ def from_pretrained( model_type = model_config.model_type - if model_type == "llama": + if model_type == "llama": scaling_type = None if getattr(model_config, "rope_scaling", None) is not None: scaling_type1 = model_config.rope_scaling.get("type", None) @@ -236,6 +221,8 @@ def from_pretrained( dispatch_model = FastQwen2Model elif model_type == "cohere": dispatch_model = FastCohereModel + elif model_type == "granite": + dispatch_model = FastGraniteModel else: raise NotImplementedError( f"Unsloth: {model_name} not supported yet!\n"\ @@ -254,8 +241,6 @@ def from_pretrained( tokenizer_name = None pass - original_kwargs = kwargs.copy() - model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -269,7 +254,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, - *args, **original_kwargs, + *args, **kwargs, ) if resize_model_vocab is not None: @@ -354,6 +339,8 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, # [TODO] No effect revision = None, + return_logits = False, # Return logits + fullgraph = True, # No graph breaks 
use_exact_model_name = False, *args, **kwargs, ): @@ -362,43 +349,17 @@ def from_pretrained( patch_compiled_autograd() patch_compiling_bitsandbytes() if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing() + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) old_model_name = model_name if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - import_from_cache = False, - disable = False, - ) - pass - # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() disable_progress_bars() - + autoconfig_error = None peft_error = None try: @@ -438,7 +399,7 @@ def from_pretrained( exist_config = os.path.exists(os.path.join(model_name, "config.json")) both_exist = exist_adapter_config and exist_config else: - files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json")) + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = (os.path.split(x)[-1] for x in files) if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2: both_exist = True @@ -471,10 +432,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT + model_name = peft_config.base_model_name_or_path if not use_exact_model_name: - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) - else: - model_name = peft_config.base_model_name_or_path + model_name = get_model_name(model_name, load_in_4bit) + model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -485,6 +446,37 @@ def from_pretrained( if not was_disabled: enable_progress_bars() + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) + pass + # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \ @@ -495,8 +487,6 @@ def from_pretrained( tokenizer_name = None pass - 
original_kwargs = kwargs.copy() - model, tokenizer = FastBaseVisionModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -508,7 +498,7 @@ def from_pretrained( revision = revision if not is_peft else None, model_types = model_types, tokenizer_name = tokenizer_name, - *args, **original_kwargs, + *args, **kwargs, ) if resize_model_vocab is not None: From 86ab9f19313ce17ba267a23c6fc77fce9eeb2175 Mon Sep 17 00:00:00 2001 From: Muhammad Osama Date: Sun, 5 Jan 2025 18:18:42 -0600 Subject: [PATCH 050/473] changing model to base_model if peft model is already used --- unsloth/models/llama.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966..128e0fd76 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,29 +1967,29 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - model.model.model.embed_tokens.modules_to_save.default\ + dtype = model.base_model.model.embed_tokens.modules_to_save.default.weight.dtype + model.base_model.model.embed_tokens.modules_to_save.default\ .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.base_model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.model.model.embed_tokens.original_module\ + model.base_model.model.embed_tokens.original_module\ .to(device = "cpu", non_blocking = True) - model.model.model.embed_tokens.original_module.requires_grad_(False) + model.base_model.model.embed_tokens.original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype - model.model.lm_head.modules_to_save.default\ + dtype = model.base_model.model.lm_head.modules_to_save.default.weight.dtype + model.base_model.lm_head.modules_to_save.default\ .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.base_model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! 
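(Editor's sketch, not part of the patch series.) Earlier in this patch the HfFileSystem globs switch from os.path.join to a hand-built forward-slash pattern because, as the patch comment notes, HfFileSystem assumes Linux-style paths even on Windows. A minimal standalone version of that check, with an invented function name:

import os
from huggingface_hub import HfFileSystem

def has_adapter_and_config(model_name: str, token = None) -> bool:
    # Forward slashes keep the glob valid on Windows, where os.path.join would
    # insert backslashes that HfFileSystem cannot resolve.
    files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
    names = [os.path.split(f)[-1] for f in files]
    # Same rule as the patched loader: when both adapter_config.json and
    # config.json are present, the repo may hold a PEFT adapter plus a full
    # model config and needs further disambiguation.
    return sum(n in ("adapter_config.json", "config.json") for n in names) >= 2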
- model.model.lm_head.original_module\ + model.base_model.lm_head.original_module\ .to(device = "cpu", non_blocking = True) - model.model.lm_head.original_module.requires_grad_(False) + model.base_model.lm_head.original_module.requires_grad_(False) pass return model From 039a507a2325fc7dce5254dc61f02829b66919c2 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:04:27 +0800 Subject: [PATCH 051/473] Improve debugging experience (#1512) * Create CONTRIBUTING.md (#1472) Creating contributing guidelines * Update CONTRIBUTING.md improved sentence * Improve logging control in `unsloth_compile_transformers` by conditionally redirecting stdout based on UNSLOTH_DISABLE_LOGGER environment variable --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> --- unsloth/models/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a88114669..acfd0129b 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -446,7 +446,9 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout( + open(os.devnull, "w") if os.environ.get("UNSLOTH_DISABLE_LOGGER", "0") != "1" else sys.stdout + ): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From f40558f5307df823fa589d5a402b87b7bc99ce1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 18:13:48 -0800 Subject: [PATCH 052/473] Update loader.py --- unsloth/models/loader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index acfd0129b..657072ab3 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -446,9 +446,10 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout( - open(os.devnull, "w") if os.environ.get("UNSLOTH_DISABLE_LOGGER", "0") != "1" else sys.stdout - ): + do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" + redirector = sys.stdout if do_logging else open(os.devnull, "w") + + with contextlib.redirect_stdout(redirector): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -478,6 +479,7 @@ def from_pretrained( return_logits = return_logits, ) pass + if do_logging: redirector.close() # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From a229db5a85c7f4795dc24b6c41c28b753c93a256 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 18:56:26 -0800 Subject: [PATCH 053/473] Update llama.py --- unsloth/models/llama.py | 48 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966..d3b51b683 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1968,8 +1968,18 @@ def get_peft_model( print("Unsloth: Training embed_tokens in mixed precision to save VRAM") dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must 
use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! @@ -1982,8 +1992,17 @@ def get_peft_model( print("Unsloth: Training lm_head in mixed precision to save VRAM") dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! @@ -2216,14 +2235,23 @@ def get_peft_model( model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing) - # Now patch lm_head and embed_tokens if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass @@ -2232,8 +2260,18 @@ def get_peft_model( assert(hasattr(model.model.lm_head, "modules_to_save")) dtype = model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) pass From b7ddf962d2f398be0286602d0fbb5b11e317887b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:05:14 -0800 Subject: [PATCH 054/473] Update llama.py --- unsloth/models/llama.py | 77 +++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d3b51b683..0cfa1d04a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ 
-1967,48 +1967,41 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.model.model.embed_tokens.original_module\ + model.get_input_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.model.model.embed_tokens.original_module.requires_grad_(False) + model.get_input_embeddings().original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! 
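(Editor's sketch, not part of the patch series.) The new_dtype handling in the hunk above applies one small promotion rule to the trainable embed_tokens / lm_head copies; written as a standalone helper with an invented name:

import torch

def promote_modules_to_save_dtype(weight_dtype: torch.dtype) -> torch.dtype:
    # float16 trainable copies are promoted to float32 because Tesla T4-class
    # GPUs cannot train these weights reliably in float16 (see unsloth
    # PR #1200); bfloat16 and float32 pass through unchanged.
    return torch.float32 if weight_dtype == torch.float16 else weight_dtype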
- model.model.lm_head.original_module\ + model.get_output_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.model.lm_head.original_module.requires_grad_(False) + model.get_output_embeddings().original_module.requires_grad_(False) pass return model @@ -2237,42 +2230,34 @@ def get_peft_model( if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) + assert(hasattr(model.get_input_embeddings(), "modules_to_save")) - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.model.lm_head, "modules_to_save")) + assert(hasattr(model.get_output_embeddings(), "modules_to_save")) - dtype = model.model.lm_head.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 2b5d4701fdbc5cf71019894688d5c6fddd65b753 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:05:44 -0800 Subject: [PATCH 055/473] Revert "Update llama.py" This reverts commit b7ddf962d2f398be0286602d0fbb5b11e317887b. 
--- unsloth/models/llama.py | 77 ++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0cfa1d04a..d3b51b683 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,41 +1967,48 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_input_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_input_embeddings().modules_to_save.default.requires_grad_(True) + model.model.model.embed_tokens.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.get_input_embeddings().original_module\ + model.model.model.embed_tokens.original_module\ .to(device = "cpu", non_blocking = True) - model.get_input_embeddings().original_module.requires_grad_(False) + model.model.model.embed_tokens.original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - - model.get_output_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_output_embeddings().modules_to_save.default.requires_grad_(True) + model.model.lm_head.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! 
- model.get_output_embeddings().original_module\ + model.model.lm_head.original_module\ .to(device = "cpu", non_blocking = True) - model.get_output_embeddings().original_module.requires_grad_(False) + model.model.lm_head.original_module.requires_grad_(False) pass return model @@ -2230,34 +2237,42 @@ def get_peft_model( if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.get_input_embeddings(), "modules_to_save")) + assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) - new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_input_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_input_embeddings().modules_to_save.default.requires_grad_(True) + model.model.model.embed_tokens.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.get_output_embeddings(), "modules_to_save")) + assert(hasattr(model.model.lm_head, "modules_to_save")) - new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_output_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_output_embeddings().modules_to_save.default.requires_grad_(True) + model.model.lm_head.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.lm_head.modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 52d2895dc26b9040a3a086a6019d4d769532eac9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:06:00 -0800 Subject: [PATCH 056/473] Update llama.py --- unsloth/models/llama.py | 69 +++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 128e0fd76..0cfa1d04a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,29 +1967,41 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.base_model.model.embed_tokens.modules_to_save.default.weight.dtype - model.base_model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.base_model.model.embed_tokens.modules_to_save.default.requires_grad_(True) 
+ new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.base_model.model.embed_tokens.original_module\ + model.get_input_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.base_model.model.embed_tokens.original_module.requires_grad_(False) + model.get_input_embeddings().original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.base_model.model.lm_head.modules_to_save.default.weight.dtype - model.base_model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.base_model.lm_head.modules_to_save.default.requires_grad_(True) + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! - model.base_model.lm_head.original_module\ + model.get_output_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.base_model.lm_head.original_module.requires_grad_(False) + model.get_output_embeddings().original_module.requires_grad_(False) pass return model @@ -2216,25 +2228,36 @@ def get_peft_model( model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing) - # Now patch lm_head and embed_tokens if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) + assert(hasattr(model.get_input_embeddings(), "modules_to_save")) - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.model.lm_head, "modules_to_save")) + assert(hasattr(model.get_output_embeddings(), "modules_to_save")) + + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See 
https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass - dtype = model.model.lm_head.modules_to_save.default.weight.dtype - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 1e8cf025c196e55c9aaf65be8d021a6f3c578efd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:30:32 -0800 Subject: [PATCH 057/473] Update llama.py --- unsloth/models/llama.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0cfa1d04a..f4ffbec4a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -996,18 +996,21 @@ def _CausalLM_fast_forward( lm_head = self.lm_head.weight logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) + dtype = lm_head.dtype if bsz == 1 and q_len == 1: - logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) + logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) logits = logits.unsqueeze(0).unsqueeze(0) elif num_logits_to_keep != 0: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(dtype)) else: RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! 
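(Editor's sketch, not part of the patch series.) The branch above decides whether logits are materialised directly or the fused linear + cross-entropy path is taken; the rule, pulled out into a helper with an invented name:

import os

def should_return_logits(bsz: int, q_len: int) -> bool:
    # Logits are returned when the user forces it via UNSLOTH_RETURN_LOGITS=1
    # or when the batch is small (bsz * q_len <= 1024); otherwise, when labels
    # are provided and cut-cross-entropy is available, the fused loss path runs.
    return os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" or bsz * q_len <= 1024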
if bsz*q_len <= 1024: RETURN_LOGITS = True if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: + + print(hidden_states, lm_head) n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, @@ -1029,7 +1032,7 @@ def _CausalLM_fast_forward( ) return output pass - logits = self.lm_head(hidden_states.to(lm_head.dtype)) + logits = self.lm_head(hidden_states.to(dtype)) pass torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) From cef7e5881fa71f336b5aab0f876a70fa3dfac825 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:34:09 -0800 Subject: [PATCH 058/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f4ffbec4a..c5c245e0a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -624,6 +624,7 @@ def LlamaModel_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass + print(inputs_embeds, inputs_embeds.dtype) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") From ca8e92cd89969ba73869a9227a462d1cc1cdf66d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:34:46 -0800 Subject: [PATCH 059/473] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c5c245e0a..0765e4289 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -624,7 +624,6 @@ def LlamaModel_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass - print(inputs_embeds, inputs_embeds.dtype) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -1011,7 +1010,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: - print(hidden_states, lm_head) n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, From dbef42d72679ad7f5ce28e56771a1f469e4ed5e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:38:04 -0800 Subject: [PATCH 060/473] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0765e4289..fe9eacd24 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -866,7 +866,9 @@ def custom_forward(*inputs): elif IS_COHERE: hidden_states = self.norm(hidden_states) else: + print(0, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) + print(1, hidden_states.dtype) pass if output_hidden_states: all_hidden_states += (hidden_states,) From 0dd136ddfe80d2c7eda718bf59b77b0ca3ae2df7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:41:16 -0800 Subject: [PATCH 061/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe9eacd24..32caa4521 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -853,6 +853,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass + print(idx, hidden_states.dtype) if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) From 
3369f0039bbb86e344e9ba36509293c442c5e332 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:41:26 -0800 Subject: [PATCH 062/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 32caa4521..1c58e34b6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -853,7 +853,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass - print(idx, hidden_states.dtype) + print(idx, hidden_states.dtype, end = " ") if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) From 61ecb22c9a2d58b8e4d05113c3cb0fe2c75134c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:45:22 -0800 Subject: [PATCH 063/473] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1c58e34b6..824a3ccf1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -789,11 +789,13 @@ def LlamaModel_fast_forward( if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 + print("***") position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None # Go through every layer! + print("START", hidden_states.dtype) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) From ec033328d596568a6c24bd4343b389eff110e9cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:50:56 -0800 Subject: [PATCH 064/473] Update llama.py --- unsloth/models/llama.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 824a3ccf1..8a1d2c99b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -498,7 +498,9 @@ def LlamaDecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states + print(501, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) + print(503, hidden_states.dtype) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states = hidden_states, causal_mask = causal_mask, @@ -510,12 +512,16 @@ def LlamaDecoderLayer_fast_forward( padding_mask = padding_mask, position_embeddings = position_embeddings, ) + print(515, hidden_states.dtype) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states + print(520, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) + print(522, hidden_states.dtype) hidden_states = self.mlp(hidden_states) + print(524, hidden_states.dtype) hidden_states = residual + hidden_states pass From fa02ce1401423e1970699edf596d24e65b260011 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:02:24 -0800 Subject: [PATCH 065/473] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8a1d2c99b..35e7a2b35 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -389,6 +389,7 @@ def LlamaAttention_fast_forward( if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) + print(392, Q.dtype, K.dtype) if 
past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -441,6 +442,7 @@ def LlamaAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass + print(445, A.dtype) attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None From 06d40574dc93af0e09dae9e8bc353f7de51428c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:02:35 -0800 Subject: [PATCH 066/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 35e7a2b35..cc3caaa95 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -389,7 +389,7 @@ def LlamaAttention_fast_forward( if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) - print(392, Q.dtype, K.dtype) + print(392, Q.dtype, K.dtype, position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) From 500479640f2cc7512d3ebd8345a2145d3fc28ab6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:05:04 -0800 Subject: [PATCH 067/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cc3caaa95..8ce319bbe 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,6 +384,7 @@ def LlamaAttention_fast_forward( else: cos, sin = rotary_emb(V, seq_len=kv_seq_len) + print(387, Q.dtype, K.dtype, position_ids) Q, K = ( fast_rope_embedding(Q, K, cos, sin) if position_ids is None From 2608fe4aa66ef9f4b82421cd0c7bf5ad367495a8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:10:41 -0800 Subject: [PATCH 068/473] Update llama.py --- unsloth/models/llama.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8ce319bbe..0765e4289 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,13 +384,11 @@ def LlamaAttention_fast_forward( else: cos, sin = rotary_emb(V, seq_len=kv_seq_len) - print(387, Q.dtype, K.dtype, position_ids) Q, K = ( fast_rope_embedding(Q, K, cos, sin) if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) - print(392, Q.dtype, K.dtype, position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -443,7 +441,6 @@ def LlamaAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - print(445, A.dtype) attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -501,9 +498,7 @@ def LlamaDecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - print(501, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) - print(503, hidden_states.dtype) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states = hidden_states, causal_mask = causal_mask, @@ -515,16 +510,12 @@ def LlamaDecoderLayer_fast_forward( padding_mask = padding_mask, position_embeddings = position_embeddings, ) - print(515, hidden_states.dtype) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - print(520, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) - print(522, hidden_states.dtype) hidden_states 
= self.mlp(hidden_states) - print(524, hidden_states.dtype) hidden_states = residual + hidden_states pass @@ -798,13 +789,11 @@ def LlamaModel_fast_forward( if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 - print("***") position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None # Go through every layer! - print("START", hidden_states.dtype) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -864,7 +853,6 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass - print(idx, hidden_states.dtype, end = " ") if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -878,9 +866,7 @@ def custom_forward(*inputs): elif IS_COHERE: hidden_states = self.norm(hidden_states) else: - print(0, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) - print(1, hidden_states.dtype) pass if output_hidden_states: all_hidden_states += (hidden_states,) From 2b3391f478cf6545c92be9421944a0e5171670fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:40:04 -0800 Subject: [PATCH 069/473] Auto change is_bfloat16_supported --- unsloth/models/_utils.py | 11 +++++++++-- unsloth/models/llama.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3752d46d0..9d75fda16 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -15,6 +15,10 @@ __version__ = "2025.1.1" __all__ = [ + "SUPPORTS_BFLOAT16", + "is_bfloat16_supported", + "USE_BFLOAT16", + "prepare_model_for_kbit_training", "xformers", "xformers_attention", @@ -30,7 +34,6 @@ "offload_to_disk", "offload_input_embeddings", "offload_output_embeddings", - "is_bfloat16_supported", "unsloth_offloaded_gradient_checkpoint", "torch_compile_options", "patch_linear_scaling", @@ -773,9 +776,13 @@ def offload_output_embeddings(model, temporary_location : str = "_unsloth_tempor pass +# Log dtype used - sometimes people use float16 on bfloat16 platforms +global USE_BFLOAT16 +USE_BFLOAT16 = SUPPORTS_BFLOAT16 # Fixes a weird Torch 2.3 bug which says T4s have bfloat16 def is_bfloat16_supported(): - return SUPPORTS_BFLOAT16 + global USE_BFLOAT16 + return SUPPORTS_BFLOAT16 and USE_BFLOAT16 pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0765e4289..4ffb18f68 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -68,6 +68,8 @@ from triton import __version__ as triton_version BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None +from ._utils import SUPPORTS_BFLOAT16, USE_BFLOAT16 + def original_apply_qkv(self, X): Q = self.q_proj(X) @@ -1387,7 +1389,8 @@ def __init__(self, # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) # Short sequences - dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 + global USE_BFLOAT16 + dtype = torch.bfloat16 if USE_BFLOAT16 else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) 
@@ -1580,7 +1583,6 @@ def from_pretrained( pass if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel - SUPPORTS_BFLOAT16 = is_bfloat16_supported() gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) @@ -1612,6 +1614,10 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) + # Log global device type used + global USE_BFLOAT16 + USE_BFLOAT16 = True if dtype == torch.bfloat16 else False + # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From a1b897ec3ab216692f4e78aef5c742ba6249417f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:43:20 -0800 Subject: [PATCH 070/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4ffb18f68..16159b128 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1617,7 +1617,8 @@ def from_pretrained( # Log global device type used global USE_BFLOAT16 USE_BFLOAT16 = True if dtype == torch.bfloat16 else False - + print(USE_BFLOAT16) + # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From ce840954589e9b96a4a5a6e0034988fcc587b6f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:51:49 -0800 Subject: [PATCH 071/473] Force data-type --- unsloth/models/_utils.py | 7 +------ unsloth/models/llama.py | 14 +++++--------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9d75fda16..86adc0e63 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -17,7 +17,6 @@ __all__ = [ "SUPPORTS_BFLOAT16", "is_bfloat16_supported", - "USE_BFLOAT16", "prepare_model_for_kbit_training", "xformers", @@ -776,13 +775,9 @@ def offload_output_embeddings(model, temporary_location : str = "_unsloth_tempor pass -# Log dtype used - sometimes people use float16 on bfloat16 platforms -global USE_BFLOAT16 -USE_BFLOAT16 = SUPPORTS_BFLOAT16 # Fixes a weird Torch 2.3 bug which says T4s have bfloat16 def is_bfloat16_supported(): - global USE_BFLOAT16 - return SUPPORTS_BFLOAT16 and USE_BFLOAT16 + return SUPPORTS_BFLOAT16 pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 16159b128..16dcd587a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -68,8 +68,6 @@ from triton import __version__ as triton_version BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None -from ._utils import SUPPORTS_BFLOAT16, USE_BFLOAT16 - def original_apply_qkv(self, X): Q = self.q_proj(X) @@ -1389,8 +1387,7 @@ def __init__(self, # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) # Short sequences - global USE_BFLOAT16 - dtype = torch.bfloat16 if USE_BFLOAT16 else torch.float16 + dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) @@ -1583,6 +1580,7 @@ def from_pretrained( pass if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel + SUPPORTS_BFLOAT16 = 
is_bfloat16_supported() gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) @@ -1611,14 +1609,12 @@ def from_pretrained( elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 + elif dtype == torch.float16 and SUPPORTS_BFLOAT16: + logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") + dtype = torch.float16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) - # Log global device type used - global USE_BFLOAT16 - USE_BFLOAT16 = True if dtype == torch.bfloat16 else False - print(USE_BFLOAT16) - # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From ad31cb699f403b333f6210668f8edfcdaba430d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:56:32 -0800 Subject: [PATCH 072/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 16dcd587a..ba98bec8b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1611,7 +1611,7 @@ def from_pretrained( dtype = torch.float16 elif dtype == torch.float16 and SUPPORTS_BFLOAT16: logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") - dtype = torch.float16 + dtype = torch.bfloat16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) From d7a2057ca60f5281fbe8d6ae0ef3e15aed60a2d9 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:41:15 +0700 Subject: [PATCH 073/473] All attention refactor fix (#1491) * change initilization of n_heads, n_kv_heads, hidden_size in llama.py * do the same for cohere, mistral, gemma2, granite * do the same for flexattention,cohere, mistral, granite --- unsloth/kernels/flex_attention.py | 10 +++++----- unsloth/models/cohere.py | 18 ++++++++++-------- unsloth/models/gemma2.py | 14 ++++++++------ unsloth/models/granite.py | 14 ++++++++------ unsloth/models/llama.py | 18 ++++++++++-------- unsloth/models/mistral.py | 12 ++++++------ 6 files changed, 47 insertions(+), 39 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 887ffca1b..6f8239422 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -43,9 +43,9 @@ # Logit softcapping @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads n_groups = self.num_key_value_groups # Grouped query attention @@ -130,7 +130,7 @@ def flex_attention(s, t): pass def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping @@ -147,9 +147,9 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): torch_tanh = torch.tanh torch_nn_functional_softmax = torch.nn.functional.softmax def slow_inference_attention_softcapping(Q, K, V, causal_mask, 
self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads n_groups = self.num_key_value_groups # Grouped query attention diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 1610949f6..0c36abf68 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -94,9 +94,9 @@ def CohereAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -259,12 +259,14 @@ def CohereAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -281,10 +283,10 @@ def CohereAttention_fast_forward_inference( self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Mistral Nemo 12b has weird dimensions - if attention_size != self.hidden_size: - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + if attention_size != hidden_size: + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: - self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 0f0a02071..be6b0469d 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -98,9 +98,9 @@ def Gemma2Attention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -255,12 +255,14 @@ def Gemma2Attention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -276,7 +278,7 @@ def Gemma2Attention_fast_forward_inference( self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, 
device = "cuda:0") # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 9466a8d6c..f8c29627f 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -84,9 +84,9 @@ def GraniteAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -257,12 +257,14 @@ def GraniteAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -278,7 +280,7 @@ def GraniteAttention_fast_forward_inference( self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ba98bec8b..5ce2f6195 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -146,12 +146,14 @@ def LlamaAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -168,10 +170,10 @@ def LlamaAttention_fast_forward_inference( self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Mistral Nemo 12b has weird dimensions - if attention_size != self.hidden_size: - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + if attention_size != hidden_size: + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: - self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") @@ -356,9 +358,9 @@ def LlamaAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index d6c694666..9a97015f9 100644 --- 
a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -64,9 +64,9 @@ def MistralAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -278,16 +278,16 @@ def MistralForCausalLM_fast_forward( # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now. def patch_mistral_nemo_attention(function): function = function.replace( - "(self.head_dim * self.num_heads) != self.hidden_size", + "(self.head_dim * self.config.num_attention_heads) != self.config.hidden_size", "False", ) function = function.replace( - "self.head_dim = self.hidden_size // self.num_heads", + "self.head_dim = self.config.hidden_size // self.config.num_attention_heads", "self.head_dim = config.head_dim", ) function = function.replace( - "self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)", - "self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)", + "self.o_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)", + "self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)", ) return function pass From 0cb9c5f667883ae54eb80c5c3bf87f44d935d72a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 03:33:46 -0800 Subject: [PATCH 074/473] Update llama.py --- unsloth/models/llama.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5ce2f6195..7d803bbe9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,6 +20,10 @@ from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version +from unsloth_zoo.utils import Version +transformers_version = Version(transformers_version) +# Transformers moved rotary embeddings out of all attention layers +IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1") from transformers.models.llama.modeling_llama import ( logger, BaseModelOutputWithPast, @@ -788,12 +792,7 @@ def LlamaModel_fast_forward( pass pass - if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): - # Transformers main has made it mandatory to pass position_embeddings - # https://github.com/huggingface/transformers/pull/34858 - position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) - else: - position_embeddings = None + position_embeddings = None # Go through every layer! 
for idx, decoder_layer in enumerate(self.layers): @@ -1886,6 +1885,13 @@ def from_pretrained( internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer + + # For transformers > 4.47.1, we need to add rotary_emb to all attention layers + if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"): + rotary_emb = model.model.rotary_emb + for layer in model.model.layers: + layer.self_attn.rotary_emb = rotary_emb + pass return model, tokenizer pass From e3a92e0e77a07f391eafc28447255d9b282c345f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 03:39:08 -0800 Subject: [PATCH 075/473] Update llama.py --- unsloth/models/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7d803bbe9..edd3ddf94 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -792,7 +792,12 @@ def LlamaModel_fast_forward( pass pass - position_embeddings = None + if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"): + # Transformers main has made it mandatory to pass position_embeddings + # https://github.com/huggingface/transformers/pull/34858 + position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) + else: + position_embeddings = None # Go through every layer! for idx, decoder_layer in enumerate(self.layers): From 422c0334c5785a2c81f7ba4d7ddae331a61b970a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Tue, 7 Jan 2025 17:19:11 +0530 Subject: [PATCH 076/473] Update granite to work with latest post_patch methods (#1502) * Update granite to work with latest post_patch methods * Pass position_embeddings for granite even if transformers<4.47 * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/granite.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index f8c29627f..e67c9f1cf 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -20,7 +20,8 @@ LlamaLinearScalingRotaryEmbedding, ) from .mistral import * - +from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit +from peft.tuners.lora import Linear4bit as Peft_Linear4bit try: from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -423,6 +424,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, config): super().__init__(config = config) +def patched_init(original_init): + def new_init(self, *args, **kwargs): + # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here + # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243 + # The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference + # So we need a way to pass this value around. 
It is probably better to pass on entire config just in case we need it later + config = kwargs.get("config", args[0] if args else None) + if config is not None: + self.config = config + original_init(self, *args, **kwargs) + return new_init + class FastGraniteModel(FastLlamaModel): @staticmethod @@ -437,12 +450,13 @@ def pre_patch(): exec(function, globals()) GraniteAttention.__init__ = eval(init_name) pass - GraniteAttention .forward = GraniteAttention_fast_forward - GraniteSdpaAttention .forward = GraniteAttention_fast_forward - GraniteFlashAttention2.forward = GraniteAttention_fast_forward - GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward - GraniteModel .forward = LlamaModel_fast_forward - GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) + GraniteAttention .forward = GraniteAttention_fast_forward + GraniteSdpaAttention .forward = GraniteAttention_fast_forward + GraniteFlashAttention2.forward = GraniteAttention_fast_forward + GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward + GraniteModel .forward = LlamaModel_fast_forward + GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) + GraniteForCausalLM .__init__ = patched_init(GraniteForCausalLM.__init__) PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward fix_prepare_inputs_for_generation(GraniteForCausalLM) @@ -454,7 +468,7 @@ def pre_patch(): @staticmethod - def post_patch(model): + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 @@ -519,7 +533,7 @@ def post_patch(model): for _ in range(3): gc.collect() torch.cuda.empty_cache() - return model + return model, tokenizer pass pass From 83b48a894bcda0fe3486129e2213cf5aee1f5f88 Mon Sep 17 00:00:00 2001 From: Z Date: Tue, 7 Jan 2025 04:58:40 -0700 Subject: [PATCH 077/473] Minor fixes for granite models (#1503) * Update granite.py Grab residual multiplier directly from layer * Update llama.py Version should read >= 4.47.1 as that is the version requiring the changes * Update granite.py * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/granite.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index e67c9f1cf..497a357fe 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -182,6 +182,11 @@ def GraniteDecoderLayer_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ): + residual_multiplier = \ + self.residual_multiplier \ + if hasattr(self, "residual_multiplier") else \ + self.config.residual_multiplier + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: residual = hidden_states hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states) @@ -197,13 +202,13 @@ def GraniteDecoderLayer_fast_forward( position_embeddings = position_embeddings, _flag_for_generation=self._flag_for_generation, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(self.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + 
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) @@ -218,13 +223,13 @@ def GraniteDecoderLayer_fast_forward( padding_mask=padding_mask, position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) pass outputs = (hidden_states,) @@ -370,6 +375,10 @@ def GraniteModel_fast_forward_inference( hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) hidden_states *= self.model.embedding_multiplier + residual_multiplier = \ + self.residual_multiplier \ + if hasattr(self, "residual_multiplier") else \ + self.config.residual_multiplier bsz, q_len, hd = hidden_states.shape seq_len = past_key_values[0][0].shape[-2] @@ -401,12 +410,12 @@ def GraniteModel_fast_forward_inference( position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) next_decoder_cache.append(present_key_value) pass From e0ccfafd107b369d765fa06b6ace098b938ec5b9 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:09:36 +0800 Subject: [PATCH 078/473] support modelscope models and datasets (#1481) * support modelscope * change modelscope args * remove useless import * remove useless import * fix * wip * fix * remove useless code * add readme * add some comments * change print to raise error * update comment * Update loader.py --------- Co-authored-by: Daniel Han --- README.md | 3 +++ unsloth-cli.py | 12 ++++++++++-- unsloth/models/loader.py | 19 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6bff98cbd..f658e6ceb 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more! - We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code! - We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)! +- If you want to download models from the ModelScope community, please use an environment variable: `UNSLOTH_USE_MODELSCOPE=1`, and install the modelscope library by: `pip install modelscope -U`. 
+ +> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` to download models and datasets. please remember to use the model and dataset id in the ModelScope community. ```python from unsloth import FastLanguageModel diff --git a/unsloth-cli.py b/unsloth-cli.py index ddb0ac8b7..b7613f92d 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -30,11 +30,14 @@ """ import argparse +import os + def run(args): import torch from unsloth import FastLanguageModel from datasets import load_dataset + from transformers.utils import strtobool from trl import SFTTrainer from transformers import TrainingArguments from unsloth import is_bfloat16_supported @@ -86,8 +89,13 @@ def formatting_prompts_func(examples): texts.append(text) return {"text": texts} - # Load and format dataset - dataset = load_dataset(args.dataset, split="train") + use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False')) + if use_modelscope: + from modelscope import MsDataset + dataset = MsDataset.load(args.dataset, split="train") + else: + # Load and format dataset + dataset = load_dataset(args.dataset, split="train") dataset = dataset.map(formatting_prompts_func, batched=True) print("Data is formatted and ready!") diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 657072ab3..e9caad0e6 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -31,6 +31,15 @@ pass from huggingface_hub import HfFileSystem +# [TODO] Move USE_MODELSCOPE to utils +USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" +if USE_MODELSCOPE: + import importlib + if importlib.util.find_spec("modelscope") is None: + raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') + pass +pass + # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 
from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) @@ -72,6 +81,11 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + if USE_MODELSCOPE and not os.path.exists(model_name): + from modelscope import snapshot_download + model_name = snapshot_download(model_name) + pass + # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() @@ -355,6 +369,11 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + if USE_MODELSCOPE and not os.path.exists(model_name): + from modelscope import snapshot_download + model_name = snapshot_download(model_name) + pass + # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() From 63ad366d0f82bbaa57858bc3120c101dc209f877 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 8 Jan 2025 12:42:18 -0800 Subject: [PATCH 079/473] Merge branch 'main' into nightly --- pyproject.toml | 4 ++-- unsloth/__init__.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bf4c99528..43ec13fd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2024.12.7", + "unsloth_zoo>=2025.1.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2024.12.7", + "unsloth_zoo>=2025.1.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index bbeded9fc..d460432bb 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -48,9 +48,11 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ - "pinned_use_cuda_host_register:True,"\ - "pinned_num_register_threads:8" + "roundup_power2_divisions:[32:256,64:128,256:64,>:32]" + +# [TODO] Check why some GPUs don't work +# "pinned_use_cuda_host_register:True,"\ +# "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From a7d783869db415d58e0ee34270ba090e00b58d46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 8 Jan 2025 14:38:41 -0800 Subject: [PATCH 080/473] Phi 4 --- unsloth/chat_templates.py | 40 +++++++++++++++++++++++++++++++++++++++ unsloth/models/_utils.py | 2 +- unsloth/models/mapper.py | 5 +++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index da10f7e00..d8dc38522 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -890,6 +890,46 @@ DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 pass +# =========================================== Phi-4 +# "{{ bos_token }}"\ # Phi-4 removes BOS? 
+phi4_template = \ + "{% for message in messages %}"\ + "{% if (message['role'] == 'system') %}"\ + "{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% elif (message['role'] == 'user') %}"\ + "{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% elif (message['role'] == 'assistant') %}"\ + "{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% endif %}"\ + "{% endfor %}"\ + "{% if add_generation_prompt %}"\ + "{{ '<|im_start|>assistant<|im_sep|>' }}"\ + "{% endif %}" +pass + +_phi4_ollama_template = \ + "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"\ + "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"\ + "<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>" + +# Ollama from https://www.ollama.com/library/phi4 is different +phi4_ollama = \ +f''' +FROM {{__FILE_LOCATION__}} +TEMPLATE """{_phi4_ollama_template}""" +PARAMETER stop "<|im_end|>" +PARAMETER stop "<|im_start|>" +PARAMETER stop "<|im_sep|>" +PARAMETER temperature 1.5 +PARAMETER min_p 0.1 +''' + +phi4_template_eos_token = "<|im_end|>" +CHAT_TEMPLATES["phi-4"] = (phi4_template, phi4_template_eos_token, False, phi4_ollama,) +DEFAULT_SYSTEM_MESSAGE["phi-4"] = None # No system message in Phi-4 +pass + + def _change_system_message(template: str, type_chat_template: str, system_message: str = None): system_message_pattern = r"\{system_message\}" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 86adc0e63..a93f18cd4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.1" +__version__ = "2025.1.2" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 41f744464..b7b24b5cc 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -520,6 +520,11 @@ "unsloth/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", ), + "unsloth/phi-4-unsloth-bnb-4bit" : ( + "unsloth/phi-4", + "microsoft/phi-4", + "unsloth/phi-4-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From 2ced650ac23c09359d0f7e76bc621fc8ba1f56ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 14 Jan 2025 22:32:44 -0800 Subject: [PATCH 081/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index edd3ddf94..7c7d66d03 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -664,7 +664,7 @@ def LlamaModel_fast_forward( # Fix up attention mask by setting elements to 0 # Specifically for DPO - if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \ + if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \ (not train_embed_tokens): # Careful for inference the attention_mask is size (1, kv_seq_len) # Whilst the input_embeds is size (1, 1, 4096) From dd9b4e1d615ee2ea0015afebca66c43df92432db Mon Sep 17 00:00:00 2001 From: AminWhat <88392440+aminwhat@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:32:23 +0330 Subject: [PATCH 082/473] Torch.Cuda Is Available Condition and Warning (#1545) * check for torch.cuda and triton if available on my machine(mac m3) the cuda were not available * Update pyproject.toml * Update __init__.py --------- Co-authored-by: Daniel Han --- 
unsloth/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8002fbaef..7f37a2069 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -86,6 +86,10 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass +# First check if CUDA is available ie a NVIDIA GPU is seen +if not torch.cuda.is_available(): + raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!") + # Fix Xformers performance issues since 0.0.25 import importlib.util from pathlib import Path From bc37b7acc82724985dc415a9abcd57724b4da7f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 00:56:56 -0800 Subject: [PATCH 083/473] Update mistral.py --- unsloth/models/mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a97015f9..e52ac2cbf 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -306,6 +306,7 @@ def pre_patch(): # Just for Mistral Nemo models! if function is not None: function = patch_mistral_nemo_attention(function) + print(function) # if True:#init_name is not None: exec(function, globals()) MistralAttention.__init__ = eval(init_name) From 2e7a88643f7a62fe2b568abe6068ca5d48d9a0a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 00:58:46 -0800 Subject: [PATCH 084/473] Update mistral.py --- unsloth/models/mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index e52ac2cbf..4edc3b799 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -305,6 +305,7 @@ def pre_patch(): ) # Just for Mistral Nemo models! if function is not None: + print(function) function = patch_mistral_nemo_attention(function) print(function) # if True:#init_name is not None: From 15e603648399cea29d24913022d3083dc799f3ce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:07:23 -0800 Subject: [PATCH 085/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0036a18c4..7ddfef6b5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,6 +847,7 @@ def patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function + print(function) return init_name, function pass From 0b6bb121693d22d2e0fb39135cfac961b4a3438e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:09:23 -0800 Subject: [PATCH 086/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7ddfef6b5..ed575a8b4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,7 +847,7 @@ def patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function - print(function) + print(exec_code) return init_name, function pass From 76403f972e2561e8390c37b7ae35ba1c0d9a7606 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:10:40 -0800 Subject: [PATCH 087/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ed575a8b4..82b9b6705 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,6 +847,7 @@ def 
patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function + print("###########") print(exec_code) return init_name, function pass From 3c4ef996cb5736fff8fc2b261c92f720c4026d39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:15:42 -0800 Subject: [PATCH 088/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 82b9b6705..76edb3ff0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -800,6 +800,7 @@ def patch_linear_scaling( f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" + print(exec_code) try: function = inspect.getsource(attention_module.__init__) except: From b4c0b02dc0727bc86bd202bf5a5518e96f8381c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:18:15 -0800 Subject: [PATCH 089/473] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 76edb3ff0..279064b5e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -801,6 +801,7 @@ def patch_linear_scaling( f"{model_name.title()}Attention, {model_name.title()}Config" print(exec_code) + print(inspect.getsource(attention_module.__init__)) try: function = inspect.getsource(attention_module.__init__) except: From 24a24bf7c7bd70856b3dec6da5e684c550100af3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:22:13 -0800 Subject: [PATCH 090/473] Fix --- unsloth/models/_utils.py | 10 ++++------ unsloth/models/mistral.py | 2 -- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 279064b5e..ff2c8726e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -799,9 +799,7 @@ def patch_linear_scaling( f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" - - print(exec_code) - print(inspect.getsource(attention_module.__init__)) + try: function = inspect.getsource(attention_module.__init__) except: @@ -845,12 +843,12 @@ def patch_linear_scaling( "self.rotary_emb = .+?\)", function, flags = re.DOTALL | re.MULTILINE, ) - if len(rotary_emb) == 0: return None, function + if len(rotary_emb) == 0: + return None, exec_code + "\n\n" + function + rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function - print("###########") - print(exec_code) return init_name, function pass diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 4edc3b799..9a97015f9 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -305,9 +305,7 @@ def pre_patch(): ) # Just for Mistral Nemo models! 
if function is not None: - print(function) function = patch_mistral_nemo_attention(function) - print(function) # if True:#init_name is not None: exec(function, globals()) MistralAttention.__init__ = eval(init_name) From a953bfc7b55f1a294af7d67ec5bd4a0f8c9aefcd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 03:09:02 -0800 Subject: [PATCH 091/473] Bug fixes --- unsloth/models/_utils.py | 8 ++++++-- unsloth/models/mistral.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ff2c8726e..2c16bf6e7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -285,7 +285,11 @@ def _is_openai_available(): return False if _is_package_available("flash_attn"): # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl" try: - from flash_attn.flash_attn_interface import flash_attn_cuda + try: + # See https://github.com/unslothai/unsloth/issues/1437 + from flash_attn.flash_attn_interface import flash_attn_gpu + except: + from flash_attn.flash_attn_interface import flash_attn_cuda HAS_FLASH_ATTENTION = True # Also check for softcapping @@ -799,7 +803,7 @@ def patch_linear_scaling( f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" - + try: function = inspect.getsource(attention_module.__init__) except: diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a97015f9..784ca9cb4 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -304,7 +304,7 @@ def pre_patch(): attention_module = MistralAttention, ) # Just for Mistral Nemo models! - if function is not None: + if function is not None and init_name is not None: function = patch_mistral_nemo_attention(function) # if True:#init_name is not None: exec(function, globals()) From e6d677bbcda6b319b598405b9aca95db9394dfab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 01:37:13 -0800 Subject: [PATCH 092/473] Update mapper.py --- unsloth/models/mapper.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index c1113f529..b7df6668b 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -471,20 +471,18 @@ "meta-llama/Llama-3.2-11B-Vision-Instruct", "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision-Instruct", "meta-llama/Llama-3.2-90B-Vision-Instruct", - "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", ), "unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-11B-Vision", "meta-llama/Llama-3.2-11B-Vision", "unsloth/Llama-3.2-11B-Vision-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision", "meta-llama/Llama-3.2-90B-Vision", - "unsloth/Llama-3.2-90B-Vision-bnb-4bit", ), "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit" : ( "unsloth/Pixtral-12B-2409", From d8d8bdc7d19b553b5f47f8af838307c20e4fccf0 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 19 Jan 2025 17:24:12 +0530 Subject: [PATCH 093/473] Add dropout to granite to match HF's implementation (#1557) Signed-off-by: datta0 --- unsloth/models/granite.py | 7 ++++--- unsloth/models/llama.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git 
a/unsloth/models/granite.py b/unsloth/models/granite.py index 497a357fe..fb7e96d8d 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -89,6 +89,7 @@ def GraniteAttention_fast_forward( n_groups = self.num_key_value_groups n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim + dropout_p = self.config.attention_dropout if self.training else 0 assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) @@ -135,7 +136,7 @@ def GraniteAttention_fast_forward( Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) pass - A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) + A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p) A = A.view(bsz, q_len, n_heads, head_dim) elif HAS_FLASH_ATTENTION and attention_mask is None: @@ -143,7 +144,7 @@ def GraniteAttention_fast_forward( K = K.transpose(1, 2) V = V.transpose(1, 2) window = (kv_seq_len, kv_seq_len) - A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) + A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p) else: # Grouped query attention # if n_groups != 1: @@ -157,7 +158,7 @@ def GraniteAttention_fast_forward( Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False) + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7c7d66d03..da3295adf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -636,6 +636,7 @@ def LlamaModel_fast_forward( IS_GEMMA2 = self.config.model_type.startswith("gemma2") IS_COHERE = self.config.model_type.startswith("cohere") IS_GRANITE = self.config.model_type.startswith("granite") + train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -792,9 +793,12 @@ def LlamaModel_fast_forward( pass pass - if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"): + if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE: # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 + # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) + # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". + # so let granite always use the attention refactor implementation. 
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None From f42d0e9b3250d80e803a1f98773b64e5abfd2116 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 15:24:14 -0800 Subject: [PATCH 094/473] Update llama.py --- unsloth/models/llama.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index da3295adf..ff52f1cff 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -949,6 +949,10 @@ def LlamaModel_fast_forward_inference( ) pass +global global_hidden_states +global global_labels +global_hidden_states = None +global_labels = None def CausalLM_fast_forward(fast_forward_inference): def _CausalLM_fast_forward( @@ -1021,6 +1025,11 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) + global global_hidden_states + global global_labels + global_hidden_states = hidden_states + global_labels = labels + raise loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, From b667bc6f6d56fbfa72469460301587558667556e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 19:19:08 -0800 Subject: [PATCH 095/473] Update llama.py --- unsloth/models/llama.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ff52f1cff..da3295adf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -949,10 +949,6 @@ def LlamaModel_fast_forward_inference( ) pass -global global_hidden_states -global global_labels -global_hidden_states = None -global_labels = None def CausalLM_fast_forward(fast_forward_inference): def _CausalLM_fast_forward( @@ -1025,11 +1021,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) - global global_hidden_states - global global_labels - global_hidden_states = hidden_states - global_labels = labels - raise loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, From 1ce40cea137f4dfedaf1e91d3203c100c024c2f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Jan 2025 01:10:55 -0800 Subject: [PATCH 096/473] Bug fixes --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b24abd355..d9df119a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.1.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.1.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 7f37a2069..4882eaf63 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if 
Version(unsloth_zoo_version) < Version("2025.1.2"): + if Version(unsloth_zoo_version) < Version("2025.1.4"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2c16bf6e7..bfb1786ee 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.5" +__version__ = "2025.1.6" __all__ = [ "SUPPORTS_BFLOAT16", From cdb32596ddccc6cbfe7662186b8486c9dd6fce3b Mon Sep 17 00:00:00 2001 From: Zhe Zhang <2631992879@qq.com> Date: Mon, 20 Jan 2025 17:25:31 +0800 Subject: [PATCH 097/473] fix: flash_attn_detection_error (#1556) * fix: flash_attn_detection_error * Update _utils.py --------- Co-authored-by: Daniel Han From 65329491b704f80183d9020cf5d67462f922545f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 31 Jan 2025 03:02:37 -0800 Subject: [PATCH 098/473] Update mapper.py --- unsloth/models/mapper.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 72619cf05..bc01c2858 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -432,21 +432,25 @@ "unsloth/Qwen2.5-Coder-32B-Instruct", "Qwen/Qwen2.5-Coder-32B-Instruct", ), - "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", + "unsloth/Llama-3.2-1B-bnb-4bit", ), - "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B", "meta-llama/Llama-3.2-3B", + "unsloth/Llama-3.2-3B-bnb-4bit", ), - "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", ), "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( "unsloth/Llama-3.1-Nemotron-70B-Instruct", @@ -550,6 +554,31 @@ "unsloth/DeepSeek-R1-Distill-Llama-70B", "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", ), + "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Base", + "mistralai/Mistral-Small-24B-Base-2501", + "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit", + ), + "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Instruct", + "mistralai/Mistral-Small-24B-Instruct-2501", + "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", + "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-7B-Instruct", + "Qwen/Qwen2.5-VL-7B-Instruct", + "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-72B-Instruct", + "Qwen/Qwen2.5-VL-72B-Instruct", + "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From ea492f2ef1b7b28529c5eeabdabe8ea2138613fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 17:54:00 -0800 Subject: [PATCH 099/473] Update gemma.py --- unsloth/models/gemma.py 
| 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index c65434328..408c55440 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -210,7 +210,14 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() - if config is not None: return # [TODO] Hack to pass in config - need to remove later + if config is not None: + # [TODO] Hack to pass in config - need to remove later + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads)) + device = "cuda" + max_position_embeddings = config.max_position_embeddings + pass self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base From e4c3557981fc113f81a77e34b982ad8520a47e45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:02:36 -0800 Subject: [PATCH 100/473] Update gemma.py --- unsloth/models/gemma.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 408c55440..53d0bb51a 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -223,6 +223,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + print(dim, max_position_embeddings, base) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) From ad3039bd79470b8a0dcb2f1d5b6464b5afcee4dc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:11:20 -0800 Subject: [PATCH 101/473] Update gemma.py --- unsloth/models/gemma.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 53d0bb51a..d94f24071 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -211,6 +211,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= ): super().__init__() if config is not None: + print(config) + print(dir(config)) # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 From ffe6a7392d100d2528096909c3f67d036bd10be3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:15:55 -0800 Subject: [PATCH 102/473] Update gemma.py --- unsloth/models/gemma.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index d94f24071..23561ed07 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -211,12 +211,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= ): super().__init__() if config is not None: - print(config) - print(dir(config)) # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads)) + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) device = "cuda" 
max_position_embeddings = config.max_position_embeddings pass From a5226ebdab7cce088e8357e343de0027c46a8847 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:27:33 -0800 Subject: [PATCH 103/473] dim fix --- unsloth/models/gemma.py | 1 - unsloth/models/llama.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 23561ed07..bc29c46ab 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -224,7 +224,6 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - print(dim, max_position_embeddings, base) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index da3295adf..4b64c74f3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1159,7 +1159,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads)) + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) device = "cuda" max_position_embeddings = config.max_position_embeddings pass From e45342c8b2403f78e24b230078af1f3ac0e03cb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 23:46:40 -0800 Subject: [PATCH 104/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b0d51a860..017b5b553 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2025.1.8" +__version__ = "2025.2.1" __all__ = [ "SUPPORTS_BFLOAT16", From c81ce12eb1a21c074e995397d28682b854732d2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 00:43:26 -0800 Subject: [PATCH 105/473] Torch 2.6 support --- pyproject.toml | 105 ++++++++++++++++++++++++++++++++++++--- unsloth/_auto_install.py | 6 ++- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d9df119a1..88c757b33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,12 @@ cu124onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch250 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -147,6 +153,12 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch251 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -163,6 +175,28 @@ cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and 
platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] +cu124onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", +] +cu126onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu118 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -223,21 +257,31 @@ cu121-torch240 = [ "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch240]", ] -cu121-torch250 = [ +cu124-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", ] -cu124-torch240 = [ +cu118-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", +] +cu121-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", ] cu124-torch250 = [ 
"unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu124onlytorch250]", ] +cu118-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", +] cu121-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -248,6 +292,21 @@ cu124-torch251 = [ "bitsandbytes>=0.43.3", "unsloth[cu124onlytorch251]", ] +cu118-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", +] +cu124-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", +] +cu126-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", +] kaggle = [ "unsloth[huggingface]", ] @@ -381,16 +440,22 @@ cu121-ampere-torch240 = [ "unsloth[cu121onlytorch240]", "unsloth[flashattention]", ] -cu121-ampere-torch250 = [ +cu124-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", "unsloth[flashattention]", ] -cu124-ampere-torch240 = [ +cu118-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", + "unsloth[flashattention]", +] +cu121-ampere-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", "unsloth[flashattention]", ] cu124-ampere-torch250 = [ @@ -399,6 +464,12 @@ cu124-ampere-torch250 = [ "unsloth[cu124onlytorch250]", "unsloth[flashattention]", ] +cu118-ampere-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", + "unsloth[flashattention]", +] cu121-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -411,6 +482,24 @@ cu124-ampere-torch251 = [ "unsloth[cu124onlytorch251]", "unsloth[flashattention]", ] +cu118-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", + "unsloth[flashattention]", +] +cu124-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", + "unsloth[flashattention]", +] +cu126-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", + "unsloth[flashattention]", +] [project.urls] homepage = "http://www.unsloth.ai" diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py index c3b94c670..8bb548519 100644 --- a/unsloth/_auto_install.py +++ b/unsloth/_auto_install.py @@ -18,14 +18,16 @@ v = V(torch.__version__) cuda = str(torch.version.cuda) is_ampere = torch.cuda.get_device_capability()[0] >= 8 -if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!") +if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!") if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!") elif v <= V('2.1.1'): x = 'cu{}{}-torch211' elif v <= V('2.1.2'): x = 'cu{}{}-torch212' elif v < V('2.3.0'): x = 'cu{}{}-torch220' elif v < V('2.4.0'): x = 'cu{}{}-torch230' elif v < V('2.5.0'): x = 'cu{}{}-torch240' -elif v < V('2.6.0'): x = 'cu{}{}-torch250' +elif v < V('2.5.1'): x = 'cu{}{}-torch250' +elif v <= V('2.5.1'): x = 'cu{}{}-torch251' +elif v < V('2.7.0'): x = 'cu{}{}-torch260' else: raise RuntimeError(f"Torch = {v} too new!") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') \ No newline at end of file From 
fb0526be6172b528edead1b5f0e98c7502e66955 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:06:13 -0800 Subject: [PATCH 106/473] Update llama.py --- unsloth/models/llama.py | 92 +++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 54 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4b64c74f3..051cd441c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2510,18 +2510,24 @@ def for_inference(model): # return # pass - internal_model = model - internal_model.gradient_checkpointing = False - internal_model.training = False - - while hasattr(internal_model, "model"): - internal_model = internal_model.model - internal_model.gradient_checkpointing = False - internal_model.training = False - pass - if hasattr(internal_model, "training"): - internal_model.training = False - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" + m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" # Also check if lm_head / embeddings are trained internal_model = model @@ -2530,30 +2536,13 @@ def for_inference(model): pass lm_head = internal_model.lm_head.weight device_type = lm_head.device.type - dtype = model.config.torch_dtype - - if type(dtype) is str: - if dtype == "float16": dtype = torch.float16 - elif dtype == "bfloat16": dtype = torch.bfloat16 - pass + dtype = _get_dtype(model.config.torch_dtype) # Wrap model.generate if model.generate.__name__ != "_fast_generate": model._unwrapped_old_generate = model.generate model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) pass - - # Patch tokenizer to pad to the left - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -2571,9 +2560,6 @@ def for_inference(model): @staticmethod def for_training(model, use_gradient_checkpointing = True): - internal_model = model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True # Delete all fast inference loras for param in model.parameters(): @@ -2581,14 +2567,24 @@ def for_training(model, use_gradient_checkpointing = True): del param._fast_lora pass - while hasattr(internal_model, "model"): - internal_model = internal_model.model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True - pass - if hasattr(internal_model, "training"): - internal_model.training = True - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" + 
m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" # Also revert model.generate if hasattr(model, "_unwrapped_old_generate"): @@ -2596,18 +2592,6 @@ def for_training(model, use_gradient_checkpointing = True): del model._unwrapped_old_generate pass - # Patch tokenizer to pad to the right - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() @@ -2617,7 +2601,7 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - + return model pass pass From f14adf1f701ce6fd48e1b64cf9485c14fa77164b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:10:17 -0800 Subject: [PATCH 107/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 051cd441c..23a8c0a68 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -262,6 +262,7 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: + print(attention_mask) A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) @@ -2601,7 +2602,7 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - + return model pass pass From 03083b6fc44056cba462ed73697539d30e2fbf57 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:11:33 -0800 Subject: [PATCH 108/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 23a8c0a68..3c3325391 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -254,6 +254,7 @@ def LlamaAttention_fast_forward_inference( # pass # Attention + print(attention_mask) if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows From 15011952ab2bc60cc74089d7a5584b8469f85852 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:14:13 -0800 Subject: [PATCH 109/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c3325391..23a8c0a68 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -254,7 +254,6 @@ def LlamaAttention_fast_forward_inference( # pass # Attention - print(attention_mask) if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems 
like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows From e6b93e2bea60367c9ba792b6bacd0e9915a60ff2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:15:54 -0800 Subject: [PATCH 110/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 23a8c0a68..143cd4165 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,7 +20,7 @@ from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version -from unsloth_zoo.utils import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) # Transformers moved rotary embeddings out of all attention layers IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1") From e550ff01f1f909ee6c05b81ac60580796d6c2527 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:18:33 -0800 Subject: [PATCH 111/473] Update llama.py --- unsloth/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 143cd4165..e69c7068f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -70,7 +70,8 @@ from huggingface_hub.utils._token import get_token pass from triton import __version__ as triton_version -BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None +HAS_XFORMERS = xformers is not None +BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None def original_apply_qkv(self, X): @@ -404,7 +405,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -978,7 +979,7 @@ def _CausalLM_fast_forward( attention_mask = attention_mask, ) else: - causal_mask = xformers.attn_bias.LowerTriangularMask() + causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From 99a87054998c5caffd228c680c8e55367ef52d46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:23:53 -0800 Subject: [PATCH 112/473] Update llama.py --- unsloth/models/llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e69c7068f..8d7871bdf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -90,6 +90,8 @@ def original_apply_o(self, X): from math import sqrt as math_sqrt KV_CACHE_INCREMENT = 256 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax +# SDPA has GQA internally +SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__ # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): @@ -244,7 +246,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if n_groups != 1: + if not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, 
n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -263,8 +265,10 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - print(attention_mask) - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + if SDPA_HAS_GQA: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) From f04336fd13617c2812de8b75e95aac625763f283 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:49:18 -0800 Subject: [PATCH 113/473] Update llama.py --- unsloth/models/llama.py | 68 ++++++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8d7871bdf..106cedbdd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -295,14 +295,23 @@ def fast_swiglu_inference(self, X): return down pass - -def fast_rms_layernorm_inference(self, X): +torch_square = torch.square +torch_mean = torch.mean +def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None): old_dtype = X.dtype - XX = X.to(torch.float32) - variance = XX.square().mean(-1, keepdim = True) + if XX is None: + XX = X.to(torch.float32) + variance = XX.square().mean(-1, keepdim = True) + else: + XX.copy_(X) + torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance) + pass variance += self.variance_epsilon XX *= variance.rsqrt_() - X = XX.to(old_dtype) # Must preserve due to residual + + if XX is None: X = XX.to(old_dtype) + else: X.copy_(XX) + X *= self.weight return X pass @@ -908,15 +917,15 @@ def LlamaModel_fast_forward_inference( attention_mask = None, ): input_ids = input_ids[:,:self.max_seq_length] - hidden_states = self.model.embed_tokens(input_ids) - hidden_states = hidden_states.to(self.config.torch_dtype) - bsz, q_len, hd = hidden_states.shape + X = self.model.embed_tokens(input_ids) + X = X.to(self.config.torch_dtype) + bsz, q_len, hd = X.shape seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (bsz, q_len), - hidden_states, + X, seq_len, sliding_window = getattr(self.config, "sliding_window", None), ) @@ -925,30 +934,47 @@ def LlamaModel_fast_forward_inference( pass next_decoder_cache = [] + residual = torch.empty_like(X) + XX = torch.empty_like(X, dtype = torch.float32) + XX2 = torch.empty_like(X, dtype = torch.float32) + variance = torch.empty((X.shape[0], X.shape[1], 1), dtype = torch.float32, device = "cuda:0") + for idx, decoder_layer in enumerate(self.model.layers): - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) - hidden_states, present_key_value = LlamaAttention_fast_forward_inference( + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.input_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X, present_key_value = LlamaAttention_fast_forward_inference( decoder_layer.self_attn, - hidden_states = hidden_states, + hidden_states = X, past_key_value = past_key_values[idx], position_ids = position_ids, attention_mask = attention_mask, do_prefill = not 
hasattr(decoder_layer.self_attn, "paged_attention"), ) - hidden_states += residual - - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) - hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states += residual + X += residual + + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.post_attention_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X = fast_swiglu_inference(decoder_layer.mlp, X) + X += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) + X = fast_rms_layernorm_inference(self.model.norm, X) return BaseModelOutputWithPast( - last_hidden_state = hidden_states, + last_hidden_state = X, past_key_values = next_decoder_cache, hidden_states = [], attentions = [], From b4cf11f4dc0ddf535c79f8818c0f2b94c7271431 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:56:13 -0800 Subject: [PATCH 114/473] Update llama.py --- unsloth/models/llama.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 106cedbdd..475b234e1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -278,15 +278,15 @@ def LlamaAttention_fast_forward_inference( torch_nn_functional_silu = torch.nn.functional.silu -def fast_swiglu_inference(self, X): +def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None): # gate = self.gate_proj(X) # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) - up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) + gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) + up = fast_linear_forward(self. 
up_proj, X, out = temp_up) gate = torch_nn_functional_silu(gate, inplace = True) gate *= up @@ -920,6 +920,7 @@ def LlamaModel_fast_forward_inference( X = self.model.embed_tokens(input_ids) X = X.to(self.config.torch_dtype) bsz, q_len, hd = X.shape + mlp_size = self.config.intermediate_size seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -935,9 +936,11 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) - XX = torch.empty_like(X, dtype = torch.float32) - XX2 = torch.empty_like(X, dtype = torch.float32) - variance = torch.empty((X.shape[0], X.shape[1], 1), dtype = torch.float32, device = "cuda:0") + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) + XX, XX2 = _XX[0], _XX[1] + variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") + temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") + temp_gate, temp_up = temp_mlp[0], temp_mlp[1] for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X @@ -966,12 +969,23 @@ def LlamaModel_fast_forward_inference( XX2 = XX2, variance = variance, ) - X = fast_swiglu_inference(decoder_layer.mlp, X) + X = fast_swiglu_inference( + decoder_layer.mlp, + X, + temp_gate = temp_gate, + temp_up = temp_up, + ) X += residual next_decoder_cache.append(present_key_value) pass - X = fast_rms_layernorm_inference(self.model.norm, X) + X = fast_rms_layernorm_inference( + self.model.norm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) return BaseModelOutputWithPast( last_hidden_state = X, From 20255efdd44a987f04a154c091710356d6dfa917 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:56:25 -0800 Subject: [PATCH 115/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 475b234e1..401b8986a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -936,6 +936,7 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) + print(bsz, q_len, hd) _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") From 04b0c4563c154e879ab71b7636db55652b46f2e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:57:24 -0800 Subject: [PATCH 116/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 401b8986a..d0ffa53d5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -937,7 +937,7 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) print(bsz, q_len, hd) - _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") From 8e8337309dd9a008cc1a53f24c07556998a36bd1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 04:00:47 -0800 Subject: [PATCH 117/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 
d0ffa53d5..97a1fc233 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -936,13 +936,12 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) - print(bsz, q_len, hd) _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - + for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( From cd4b0393cf1d480c00d85f4498d03bc361cc6290 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:45:25 -0800 Subject: [PATCH 118/473] Faster inference? --- unsloth/kernels/utils.py | 9 ++++++--- unsloth/models/llama.py | 24 ++++++++++++++++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index de543962e..57df0d6b3 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -15,6 +15,7 @@ import triton MAX_FUSED_SIZE : int = 65536 next_power_of_2 = triton.next_power_of_2 +import functools # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -96,18 +97,20 @@ def get_lora_parameters(proj): pass +@functools.cache def get_lora_parameters_bias(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight bias = base_layer.bias - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if getattr(proj, "disable_adapters", True) or proj.merged: return W, QUANT_STATE(W), None, None, None, bias pass active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter + getattr(proj, "active_adapters", ) else proj.active_adapter A = proj.lora_A [active_adapter].weight B = proj.lora_B [active_adapter].weight s = proj.scaling[active_adapter] diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 97a1fc233..c91f04073 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -917,10 +917,23 @@ def LlamaModel_fast_forward_inference( attention_mask = None, ): input_ids = input_ids[:,:self.max_seq_length] + bsz, q_len = input_ids.shape + hd = self.config.hidden_size + mlp_size = self.config.intermediate_size + + # Get saved buffers to reduce memory movement + residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + XX, XX2 = _XX[0], _XX[1] + variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") + temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") + temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + X = self.model.embed_tokens(input_ids) X = X.to(self.config.torch_dtype) bsz, q_len, hd = X.shape - mlp_size = self.config.intermediate_size + assert(q_len == 1) + seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -933,15 +946,10 @@ def LlamaModel_fast_forward_inference( else: attention_mask = None pass + print(attention_mask) next_decoder_cache = [] - 
residual = torch.empty_like(X) - _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") - XX, XX2 = _XX[0], _XX[1] - variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") - temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - + for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( From c7ac842da892d68fc42c11184772e4b8d953a962 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:47:14 -0800 Subject: [PATCH 119/473] Update llama.py --- unsloth/models/llama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c91f04073..d1d5f5e16 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -920,6 +920,11 @@ def LlamaModel_fast_forward_inference( bsz, q_len = input_ids.shape hd = self.config.hidden_size mlp_size = self.config.intermediate_size + + X = self.model.embed_tokens(input_ids) + X = X.to(self.config.torch_dtype) + bsz, q_len, hd = X.shape + assert(q_len == 1) # Get saved buffers to reduce memory movement residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") @@ -929,11 +934,6 @@ def LlamaModel_fast_forward_inference( temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - X = self.model.embed_tokens(input_ids) - X = X.to(self.config.torch_dtype) - bsz, q_len, hd = X.shape - assert(q_len == 1) - seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( From 0575002599494549ec9f8f641e28c7aa8cbc1221 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:49:25 -0800 Subject: [PATCH 120/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d1d5f5e16..cafec19cb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -946,7 +946,6 @@ def LlamaModel_fast_forward_inference( else: attention_mask = None pass - print(attention_mask) next_decoder_cache = [] From cc88d1b9e6ea9595786df65f1189ac3ea476104f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:52:15 -0800 Subject: [PATCH 121/473] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 57df0d6b3..f8690a17a 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -97,7 +97,6 @@ def get_lora_parameters(proj): pass -@functools.cache def get_lora_parameters_bias(proj): # For DPO or disabled adapters base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) From 19c4085f17e932ae1acfa5e8da56625a715ea9d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 14:51:14 -0800 Subject: [PATCH 122/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cafec19cb..3ed292082 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -260,6 +260,7 @@ def LlamaAttention_fast_forward_inference( if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * 
scalar to stop overflows + print(Qn.shape, Knn.transpose(2, 3).shape, self.attention[:,:,:,:cached_len].shape, self.attention.shape, cached_len) A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) From 1ff67e3b35af85d5f9d36b453577e47a5a6c418e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 14:54:58 -0800 Subject: [PATCH 123/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3ed292082..d6a4fa107 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -246,7 +246,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if not SDPA_HAS_GQA and n_groups != 1: + if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -260,7 +260,6 @@ def LlamaAttention_fast_forward_inference( if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows - print(Qn.shape, Knn.transpose(2, 3).shape, self.attention[:,:,:,:cached_len].shape, self.attention.shape, cached_len) A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) From 8b37bc1f7af4b7efc9224c0fc92bc790a7223007 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 15:17:50 -0800 Subject: [PATCH 124/473] Update utils.py --- unsloth/kernels/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f8690a17a..762219220 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -227,7 +227,7 @@ def fast_gemv(X, W, quant_state, out = None): if quant_state is None: return torch.matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 - _, q_len, hd = X.shape + bsz, q_len, hd = X.shape # assert(q_len == 1) if type(quant_state) is not list: @@ -254,7 +254,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0") + out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda:0") # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -284,8 +284,9 @@ def fast_gemv(X, W, quant_state, out = None): cgemm_4bit_inference_naive_bf16 blocksize = ctypes.c_int32(blocksize) - fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + for i in range(bsz): + fx(m, n, k, get_ptr(X[i]), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out[i]), + lda, ldb, ldc, blocksize, CUDA_STREAM,) return out pass From b734d728fa88242128f77654234d77281351d1af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 15:23:05 -0800 
Subject: [PATCH 125/473] Update utils.py --- unsloth/kernels/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 762219220..f8690a17a 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -227,7 +227,7 @@ def fast_gemv(X, W, quant_state, out = None): if quant_state is None: return torch.matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 - bsz, q_len, hd = X.shape + _, q_len, hd = X.shape # assert(q_len == 1) if type(quant_state) is not list: @@ -254,7 +254,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda:0") + out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0") # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -284,9 +284,8 @@ def fast_gemv(X, W, quant_state, out = None): cgemm_4bit_inference_naive_bf16 blocksize = ctypes.c_int32(blocksize) - for i in range(bsz): - fx(m, n, k, get_ptr(X[i]), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out[i]), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), + lda, ldb, ldc, blocksize, CUDA_STREAM,) return out pass From 9c7618cced69b4a9f904a80a242126757069b80a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:18:56 -0800 Subject: [PATCH 126/473] Update utils.py --- unsloth/kernels/utils.py | 73 ++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f8690a17a..c5df015ca 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -116,9 +116,13 @@ def get_lora_parameters_bias(proj): return W, QUANT_STATE(W), A, B, s, bias pass +global WEIGHT_BUFFER +WEIGHT_BUFFER = None +global ABSMAX_BUFFER +ABSMAX_BUFFER = None if HAS_CUDA_STREAM: - def fast_dequantize(W, quant_state = None, out = None): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -141,18 +145,34 @@ def fast_dequantize(W, quant_state = None, out = None): global CUDA_STREAM if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0") + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + if use_global_buffer: + + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0") + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + pass # NF4 dequantization of statistics - n_elements_absmax = 
absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - - # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, @@ -160,6 +180,7 @@ def fast_dequantize(W, quant_state = None, out = None): ) out_absmax += offset + # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), @@ -170,7 +191,7 @@ def fast_dequantize(W, quant_state = None, out = None): return out.t() if is_transposed else out pass else: - def fast_dequantize(W, quant_state = None, out = None): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -191,16 +212,32 @@ def fast_dequantize(W, quant_state = None, out = None): absmax2, code2, blocksize2, _, _, _, _ = state2 pass + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") - else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if use_global_buffer: - # NF4 dequantization of statistics - n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] + else: + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0") + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + pass # Do dequantization ptr_out_absmax = get_ptr(out_absmax) From e530002aa19cdb28cbb3085f8ffa6af897bfb07e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:23:31 -0800 Subject: [PATCH 127/473] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c5df015ca..753eda5b3 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -404,7 +404,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch.matmul(X, W, out = out) pass @@ -438,7 +438,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape From 404ac62e2edc6cd24f379d697f98f5a3db86c24c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:24:10 -0800 Subject: [PATCH 128/473] Update utils.py --- unsloth/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 753eda5b3..037c8c8a1 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -226,7 +226,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) - if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] From 78395a4b1c87174aac241328207cf40be887583e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:26:55 -0800 Subject: [PATCH 129/473] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 037c8c8a1..d470c0f87 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -219,6 +219,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Use same buffers for faster inference size = shape[0]*shape[1] + print(shape, size) global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: From 4386c2a4ddc7f922ea643d0d6822da2ad13f0b99 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:29:26 -0800 Subject: [PATCH 130/473] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index d470c0f87..c378d4d73 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -154,12 +154,13 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False size = shape[0]*shape[1] global WEIGHT_BUFFER global ABSMAX_BUFFER + print(size, shape) if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) - if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] @@ -219,7 +220,6 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Use same buffers for faster inference size = shape[0]*shape[1] - print(shape, size) global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: From 62fe595cb76e8a7a08be99ecc88acd071d40a2f9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:31:28 -0800 Subject: [PATCH 131/473] Update utils.py --- unsloth/kernels/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c378d4d73..645727956 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -154,10 +154,9 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False size = shape[0]*shape[1] global WEIGHT_BUFFER global ABSMAX_BUFFER - print(size, shape) if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") if size > 
WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) From 366ca87ad363418a6cceb92395d7a5b54f22b900 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:00:24 -0800 Subject: [PATCH 132/473] Update utils.py --- unsloth/kernels/utils.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 645727956..d4f31d0e4 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -242,15 +242,25 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + ) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -393,6 +403,9 @@ def fast_gemv(X, W, quant_state, out = None): pass +torch_mm = torch.mm +torch_mv = torch.mv +torch_matmul = torch.matmul def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) @@ -405,7 +418,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): out = fast_gemv(X, W, W_quant, out = out) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - out = torch.matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) pass # Add in LoRA weights @@ -420,11 +433,11 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if bsz == 1: out = out.view(out_dim) - temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora) + temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) else: out = out.view(bsz, out_dim) - temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) + temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) pass out = out.view(bsz, 1, out_dim) From 5d0f36a1968966e71a46e0dfd21e4269ec34e077 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:04:37 -0800 Subject: [PATCH 133/473] Update utils.py --- unsloth/kernels/utils.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index d4f31d0e4..8537e9595 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -175,16 +175,27 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM, + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + 
ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), + CUDA_STREAM, ) out_absmax += offset # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,) + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + CUDA_STREAM,) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -242,25 +253,15 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), - get_ptr(absmax), - get_ptr(absmax2), - ptr_out_absmax, - ctypes.c_int(blocksize2), - ctypes.c_int(n_elements_absmax), + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx( - get_ptr(None), - get_ptr(W), - ptr_out_absmax, - get_ptr(out), - ctypes.c_int(blocksize), - ctypes.c_int(out.numel()), - ) + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From ed596d9b4655cd5b84a4e59d8634b81e206c8235 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:07:31 -0800 Subject: [PATCH 134/473] Update utils.py --- unsloth/kernels/utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 8537e9595..3b0c1d391 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -121,6 +121,8 @@ def get_lora_parameters_bias(proj): global ABSMAX_BUFFER ABSMAX_BUFFER = None +ctypes_c_int = ctypes.c_int + if HAS_CUDA_STREAM: def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W @@ -157,12 +159,14 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + ABSMAX_BUFFER.ptr_out_absmax = get_ptr(ABSMAX_BUFFER) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] + ptr_out_absmax = ABSMAX_BUFFER.ptr_out_absmax else: if out is None: out = torch.empty(shape, dtype = dtype, device = "cuda:0") @@ -170,19 +174,20 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False assert(out.shape == shape) assert(out.dtype == dtype) out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + ptr_out_absmax = get_ptr(out_absmax) pass # NF4 dequantization of statistics - ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), - ctypes.c_int(n_elements_absmax), + ctypes_c_int(blocksize2), + ctypes_c_int(n_elements_absmax), CUDA_STREAM, ) + print(offset, out_absmax) out_absmax += 
offset # Dequantize W @@ -193,8 +198,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), - ctypes.c_int(out.numel()), + ctypes_c_int(blocksize), + ctypes_c_int(out.numel()), CUDA_STREAM,) # Careful returning transposed data @@ -254,14 +259,14 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()),) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From d652dc15e36f0a5069c6c76a654dd4453cd76f10 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:12:21 -0800 Subject: [PATCH 135/473] Update utils.py --- unsloth/kernels/utils.py | 59 +++++++++++++++------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 3b0c1d391..ac468e43a 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -67,6 +67,7 @@ def calculate_settings(n : int) -> (int, int,): CUDA_STREAM = None get_ptr = bnb.functional.get_ptr import ctypes +ctypes_c_int = ctypes.c_int cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 @@ -121,8 +122,6 @@ def get_lora_parameters_bias(proj): global ABSMAX_BUFFER ABSMAX_BUFFER = None -ctypes_c_int = ctypes.c_int - if HAS_CUDA_STREAM: def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W @@ -159,14 +158,12 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - ABSMAX_BUFFER.ptr_out_absmax = get_ptr(ABSMAX_BUFFER) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] - ptr_out_absmax = ABSMAX_BUFFER.ptr_out_absmax else: if out is None: out = torch.empty(shape, dtype = dtype, device = "cuda:0") @@ -174,33 +171,21 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False assert(out.shape == shape) assert(out.dtype == dtype) out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - ptr_out_absmax = get_ptr(out_absmax) pass # NF4 dequantization of statistics + ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), - get_ptr(absmax), - get_ptr(absmax2), - ptr_out_absmax, - ctypes_c_int(blocksize2), - ctypes_c_int(n_elements_absmax), - CUDA_STREAM, + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes_c_int(blocksize2), 
ctypes_c_int(n_elements_absmax), CUDA_STREAM, ) - print(offset, out_absmax) out_absmax += offset # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx( - get_ptr(None), - get_ptr(W), - ptr_out_absmax, - get_ptr(out), - ctypes_c_int(blocksize), - ctypes_c_int(out.numel()), - CUDA_STREAM,) + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -318,17 +303,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM, + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, ) df += offset absmax = df @@ -336,7 +321,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize, CUDA_STREAM,) @@ -382,17 +367,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), ) df += offset absmax = df @@ -400,7 +385,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize,) From ec266cf4891854823adf64f2415f374ee43c6fdb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 15:43:40 -0800 Subject: [PATCH 136/473] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index ac468e43a..66a1a4895 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -404,7 +404,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: - out = torch.matmul(X, W.t(), out = out) + out = torch_matmul(X, W.t(), out = out) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -452,7 +452,7 @@ 
def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass - out = torch.matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: From b861b662a02470e402df548fef6aecaaf9d208fa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 19:43:52 -0800 Subject: [PATCH 137/473] Update mapper.py --- unsloth/models/mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index bc01c2858..6e6e402a0 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -555,12 +555,12 @@ "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", ), "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-24B-Base", + "unsloth/Mistral-Small-24B-Base-2501", "mistralai/Mistral-Small-24B-Base-2501", "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit", ), "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-24B-Instruct", + "unsloth/Mistral-Small-24B-Instruct-2501", "mistralai/Mistral-Small-24B-Instruct-2501", "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", ), From ba151161028fe20de1c1cb4fb1341e480a7446fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:21:15 -0800 Subject: [PATCH 138/473] Fast Inference via vLLM --- unsloth/models/llama.py | 84 +++++++++++++++++++++++++++++++++------- unsloth/models/loader.py | 43 +++++++++++++++++++- 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6a4fa107..b350f764c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1634,9 +1634,18 @@ def from_pretrained( model_patcher = None, tokenizer_name = None, trust_remote_code = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = True, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, **kwargs, ): if trust_remote_code: + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.") print( "Unsloth: WARNING `trust_remote_code` is True.\n"\ "Are you certain you want to do remote code execution?" @@ -1650,9 +1659,9 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ - f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ - f"O^O/ \_/ \\ Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ - f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ + f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ + f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' print(statistics) @@ -1680,7 +1689,11 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) # RoPE Scaling - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_config = AutoConfig.from_pretrained( + model_name, + token = token, + attn_implementation = "sdpa", + ) model_max_seq_length = model_config.max_position_embeddings # Check if RoPE Scaling is even allowed @@ -1701,6 +1714,9 @@ def from_pretrained( rope_scaling = max_seq_length / model_max_seq_length + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.") + logger.warning_once( f"Unsloth: {model_name} can only handle sequence lengths of at most "\ f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\ @@ -1742,17 +1758,55 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map = device_map, - torch_dtype = dtype, - # quantization_config = bnb_config, - token = token, - max_position_embeddings = max_position_embeddings, - trust_remote_code = trust_remote_code, - attn_implementation = "eager", - **kwargs, - ) + if not fast_inference: + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map = device_map, + torch_dtype = dtype, + # quantization_config = bnb_config, + token = token, + max_position_embeddings = max_position_embeddings, + trust_remote_code = trust_remote_code, + attn_implementation = "eager", + **kwargs, + ) + else: + from unsloth_zoo.vllm_utils import ( + load_vllm, + get_vllm_state_dict, + convert_vllm_to_huggingface, + generate_batches, + ) + allowed_args = inspect.getfullargspec(load_vllm).args + load_vllm_kwargs = dict( + model_name = model_name, + config = model_config, + gpu_memory_utilization = gpu_memory_utilization, + max_seq_length = max_seq_length, + dtype = dtype, + disable_log_stats = disable_log_stats, + float8_kv_cache = float8_kv_cache, + enable_lora = True, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + ) + for allowed_arg in allowed_args: + if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: + load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg] + pass + + # Load vLLM first + llm = load_vllm(**load_vllm_kwargs) + + # Convert to HF format + _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) + model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) + model.vllm_engine = llm + model.fast_generate = model.vllm_engine.generate + + from functools import partial + model.fast_generate_batches = partial(generate_batches, model.vllm_engine) + pass # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer # We currently only support NVIDIA GPUs - AMD / Intel is a work in progress! 
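A rough usage sketch (illustrative only, not part of this patch): with the fast_inference path above, from_pretrained loads a vLLM engine alongside the HF-format model and exposes it as model.fast_generate / model.fast_generate_batches. Assuming vLLM is installed and using a placeholder model id:

    from unsloth import FastLanguageModel
    from vllm import SamplingParams

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",  # placeholder model id
        max_seq_length = 2048,
        load_in_4bit = True,
        fast_inference = True,               # builds the vLLM engine added in this patch
        gpu_memory_utilization = 0.5,
        max_lora_rank = 16,
    )
    sampling_params = SamplingParams(temperature = 0.8, max_tokens = 64)
    outputs = model.fast_generate(["Why is the sky blue?"], sampling_params)

The sampling arguments here are only an example; fast_generate simply forwards to vllm_engine.generate.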
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e9caad0e6..144863b8d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -30,11 +30,11 @@ from huggingface_hub.utils._token import get_token pass from huggingface_hub import HfFileSystem +import importlib.util # [TODO] Move USE_MODELSCOPE to utils USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" if USE_MODELSCOPE: - import importlib if importlib.util.find_spec("modelscope") is None: raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') pass @@ -73,9 +73,25 @@ def from_pretrained( resize_model_vocab = None, revision = None, use_exact_model_name = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = True, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, *args, **kwargs, ): if token is None: token = get_token() + + if fast_inference: + if importlib.util.find_spec("vllm") is None: + raise ImportError( + "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\ + "You can do this in a terminal via `pip install vllm`" + ) + pass + pass old_model_name = model_name if not use_exact_model_name: @@ -255,6 +271,24 @@ def from_pretrained( tokenizer_name = None pass + if fast_inference: + from unsloth_zoo.vllm_utils import ( + patch_vllm, + vllm_dynamic_quant_supported, + ) + patch_vllm() + if model_name.endswith("unsloth-bnb-4bit"): + if not vllm_dynamic_quant_supported(model_name, model_config): + # Instead use -bnb-4bit variant + print( + f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\ + f"we do not yet support fast inference for {model_name}" + ) + model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit" + pass + pass + pass + model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -268,6 +302,13 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, + + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, *args, **kwargs, ) From d2aef048e0e4f0d0de3e4f19a892f1357f0eba2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:30:51 -0800 Subject: [PATCH 139/473] Update llama.py --- unsloth/models/llama.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b350f764c..1a700c62d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1784,7 +1784,6 @@ def from_pretrained( gpu_memory_utilization = gpu_memory_utilization, max_seq_length = max_seq_length, dtype = dtype, - disable_log_stats = disable_log_stats, float8_kv_cache = float8_kv_cache, enable_lora = True, max_lora_rank = max_lora_rank, @@ -2302,6 +2301,20 @@ def get_peft_model( modules_to_save = list(set(modules_to_save)) pass + vllm_engine = None + if hasattr(model, "vllm_engine"): + # Fast inference! 
+ vllm_engine = model.vllm_engine + vllm_fast_generate = model.fast_generate + vllm_fast_generate_batches = model.fast_generate_batches + + if len(modules_to_save) != 0: + raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.") + + if bias != "none": + raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.") + pass + # Get LoRA arguments = dict( r = r, @@ -2408,6 +2421,19 @@ def get_peft_model( torch.cuda.empty_cache() pass + # Patch for fast inference + if vllm_engine is not None: + model.vllm_engine = vllm_engine + model.fast_generate = vllm_fast_generate + model.fast_generate_batches = vllm_fast_generate_batches + + # Also saving and loading LoRA + from functools import partial + from unsloth_zoo.vllm_utils import save_lora, load_lora + model.save_lora = partial(save_lora, model) + model.load_lora = partial(load_lora, model) + pass + return model pass From 48bdd41631b775635d09f349cee70a4d9c8cbf24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:56:16 -0800 Subject: [PATCH 140/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1a700c62d..ab90d2cbb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2308,7 +2308,7 @@ def get_peft_model( vllm_fast_generate = model.fast_generate vllm_fast_generate_batches = model.fast_generate_batches - if len(modules_to_save) != 0: + if modules_to_save is not None: raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.") if bias != "none": From 2a8ba7ba3a3bfc5f84196df555d5269713369b23 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 04:02:40 -0800 Subject: [PATCH 141/473] Update utils.py --- unsloth/kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 66a1a4895..165950a91 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -67,7 +67,8 @@ def calculate_settings(n : int) -> (int, int,): CUDA_STREAM = None get_ptr = bnb.functional.get_ptr import ctypes -ctypes_c_int = ctypes.c_int +ctypes_c_int = ctypes.c_int +ctypes_c_int32 = ctypes.c_int32 cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 From cf13d541243fbb7c9c7a51f6b58d38aea0c478dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:14:01 -0800 Subject: [PATCH 142/473] Create rl.py --- unsloth/models/rl.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 unsloth/models/rl.py diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py new file mode 100644 index 000000000..efe2d33e0 --- /dev/null +++ b/unsloth/models/rl.py @@ -0,0 +1,39 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "patch_rl", +] + +from trl.models.utils import unwrap_model_for_generation +from contextlib import contextmanager + + +def patch_rl(FastLanguageModel): + @contextmanager + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + FastLanguageModel.for_inference(model) + yield unwrap_model_for_generation(model, *args, **kwargs) + FastLanguageModel.for_training (model) + pass + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + unwrap = "unwrap_model_for_generation" + for trainer in trainers: + if hasattr(eval(f"trl.trainer.{trainer}"), unwrap): + exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + pass +pass From 38e6ec2d81674378245e5be4f9e7d7a4e3ab5d5c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:17:38 -0800 Subject: [PATCH 143/473] PatchRL --- unsloth/models/__init__.py | 1 + unsloth/models/rl.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index c52d14f40..3478dfc31 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,3 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported +from .rl import PatchRL diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index efe2d33e0..2aa8f0265 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -13,14 +13,15 @@ # limitations under the License. 
__all__ = [ - "patch_rl", + "PatchRL", ] -from trl.models.utils import unwrap_model_for_generation -from contextlib import contextmanager +def PatchRL(FastLanguageModel): -def patch_rl(FastLanguageModel): + from trl.models.utils import unwrap_model_for_generation + from contextlib import contextmanager + @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(model) From 886b3c82905536ffbc983352a79f14da219b9cac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:19:37 -0800 Subject: [PATCH 144/473] Update rl.py --- unsloth/models/rl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2aa8f0265..2bd602e09 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -21,11 +21,12 @@ def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation from contextlib import contextmanager - + @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(model) - yield unwrap_model_for_generation(model, *args, **kwargs) + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model FastLanguageModel.for_training (model) pass From 8724b1af04e7982b7d41635d0534356d61484120 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:23:19 -0800 Subject: [PATCH 145/473] Update rl.py --- unsloth/models/rl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2bd602e09..cea08bbc3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -24,9 +24,12 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + # Must use for_inference to allow inference in Unsloth FastLanguageModel.for_inference(model) - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: - yield unwrapped_model + with torch.inference_mode(): + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model + # Return back to training mode FastLanguageModel.for_training (model) pass From 870bd33599f88afffdfb0cc1fa32b86b276921a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:24:36 -0800 Subject: [PATCH 146/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cea08bbc3..b041277e4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,6 +16,7 @@ "PatchRL", ] +import torch def PatchRL(FastLanguageModel): From efa4bd86cea0d47ce9c0d20a327926c7eba30061 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:36:04 -0800 Subject: [PATCH 147/473] PatchRLStatistics --- unsloth/models/__init__.py | 2 +- unsloth/models/rl.py | 131 +++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3478dfc31..279080173 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchRL +from .rl import PatchRL, PatchRLStatistics diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b041277e4..f8d4d5412 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -14,9 +14,22 @@ __all__ = 
[ "PatchRL", + "PatchRLStatistics", ] import torch +try: + from transformers.utils.notebook import ( + IntervalStrategy, + NotebookTrainingTracker, + NotebookProgressCallback, + ) + HAS_NOTEBOOK = True +except: + HAS_NOTEBOOK = False +pass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + def PatchRL(FastLanguageModel): @@ -43,3 +56,121 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") pass pass + + +def NotebookProgressCallback_on_train_begin(Trainer_metrics): + def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): + self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" + self.training_loss = 0 + self.last_log = 0 + column_names = [self.first_column] + ["Training Loss"] + if args.eval_strategy != IntervalStrategy.NO: + column_names.append("Validation Loss") + column_names += [x.replace("/", " / ") for x in Trainer_metrics] + self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) + pass + return _NotebookProgressCallback_on_train_begin +pass + + +def NotebookProgressCallback_on_log(Trainer_metrics): + def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): + # Only for when there is no evaluation + if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + values = {"Training Loss": logs["loss"]} + for metric in DPOTrainer_metrics: + values[metric.replace("/", " / ")] = logs[metric] + pass + # First column is necessarily Step since we're not in epoch eval strategy + values["Step"] = state.global_step + self.training_tracker.write_line(values) + pass + pass + return _NotebookProgressCallback_on_log +pass + + +def _NotebookTrainingTracker_write_line(Trainer_metrics): + set_Trainer_metrics = set(Trainer_metrics) + def NotebookTrainingTracker_write_line(self, values): + """ + Write the values in the inner table. + + Args: + values (`Dict[str, float]`): The values to display. 
+ """ + if self.inner_table is None: + self.inner_table = [list(values.keys()), list(values.values())] + else: + columns = self.inner_table[0] + new_values = {} + for key, value in values.items(): + lowered = key.lower() + if lowered in set_Trainer_metrics: + new_values[lowered.replace("/", " / ")] = value + else: + new_values[key] = value + pass + values = new_values + + self.inner_table[0] = columns + if len(self.inner_table) > 1: + last_values = self.inner_table[-1] + first_column = self.inner_table[0][0] + if last_values[0] != values[first_column]: + # write new line + self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) + else: + # update last line + new_values = values + for c in columns: + if c not in new_values.keys(): + new_values[c] = last_values[columns.index(c)] + self.inner_table[-1] = [new_values[c] for c in columns] + else: + # Edit for evaluation purposes + self.inner_table.append([values[c] if c in values else 0 for c in columns]) + pass + pass + pass + return NotebookTrainingTracker_write_line +pass + + +def _PatchRLStatistics(metrics): + if HAS_NOTEBOOK: + from transformers.trainer import is_in_notebook + if is_in_notebook(): + # Patch DPO notebook printing + NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) + from transformers.trainer import DEFAULT_PROGRESS_CALLBACK + DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) + DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + pass + pass +pass + + +def PatchRLStatistics(algorithm = "grpo"): + if algorithm == "grpo": + metrics = [ + "completion_length", + "reward", + "reward_std", + "kl", + ] + elif algorithm == "dpo" or algorithm == "kto": + metrics = [ + "rewards/chosen", + "rewards/rejected", + "rewards/accuracies", + "rewards/margins", + "logps/rejected", + "logps/chosen", + "logits/rejected", + "logits/chosen", + ] + else: + print(f"Unsloth for {algorithm.upper()} is not yet implemented! 
Just ignore this function.") + _PatchRLStatistics(metrics) +pass From 3848350944958632979f1258287a8c22fcff19e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:36:51 -0800 Subject: [PATCH 148/473] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f8d4d5412..40979ec76 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -151,15 +151,16 @@ def _PatchRLStatistics(metrics): pass -def PatchRLStatistics(algorithm = "grpo"): - if algorithm == "grpo": +def PatchRLStatistics(algorithm = "GRPO"): + algorithm = algorithm.upper() + if algorithm == "GRPO": metrics = [ "completion_length", "reward", "reward_std", "kl", ] - elif algorithm == "dpo" or algorithm == "kto": + elif algorithm == "DPO" or algorithm == "KTO": metrics = [ "rewards/chosen", "rewards/rejected", From f8b03ee90ce31341ad1cbde9822719418ca23cc4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:45:05 -0800 Subject: [PATCH 149/473] Update rl.py --- unsloth/models/rl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 40979ec76..0e9e28b48 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,10 +39,9 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - FastLanguageModel.for_inference(model) - with torch.inference_mode(): - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: - yield unwrapped_model + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + yield unwrapped_model # Return back to training mode FastLanguageModel.for_training (model) pass From 44db7fcba191d5ec5c73517af0b86f76638e1be0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:47:23 -0800 Subject: [PATCH 150/473] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0e9e28b48..caf12cd6d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -89,9 +89,9 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw pass -def _NotebookTrainingTracker_write_line(Trainer_metrics): +def NotebookTrainingTracker_write_line(Trainer_metrics): set_Trainer_metrics = set(Trainer_metrics) - def NotebookTrainingTracker_write_line(self, values): + def _NotebookTrainingTracker_write_line(self, values): """ Write the values in the inner table. 
@@ -132,7 +132,7 @@ def NotebookTrainingTracker_write_line(self, values): pass pass pass - return NotebookTrainingTracker_write_line + return _NotebookTrainingTracker_write_line pass From deb7a8711db1150def95751e4d96cffcf82d46c6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:01:38 -0800 Subject: [PATCH 151/473] Update utils.py --- unsloth/kernels/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 165950a91..0bfd4269b 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -157,8 +157,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: - WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) @@ -167,11 +167,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) pass # NF4 dequantization of statistics @@ -224,8 +224,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: - WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0", requires_grad = False) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) @@ -234,11 +234,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) pass # Do dequantization From 47c9ff3d82e159deef74516ea31a0c4eb8d733d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:02:42 -0800 Subject: [PATCH 152/473] Update utils.py --- unsloth/kernels/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 0bfd4269b..f052914f9 100644 --- a/unsloth/kernels/utils.py +++ 
b/unsloth/kernels/utils.py @@ -124,6 +124,7 @@ def get_lora_parameters_bias(proj): ABSMAX_BUFFER = None if HAS_CUDA_STREAM: + @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: @@ -193,6 +194,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False return out.t() if is_transposed else out pass else: + @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: From 7bec3c17dfabb6241a8114c484325c107ada2274 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:14:12 -0800 Subject: [PATCH 153/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index caf12cd6d..c4a835ed9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth + print("$$$$$$$$$$$$$$$$$$$$$$$") with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model From 2c0c7b3d7a7cf5fc3c62259fa0a7e5ca988c1176 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:28:54 -0800 Subject: [PATCH 154/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c4a835ed9..932a29f78 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,12 +39,12 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - print("$$$$$$$$$$$$$$$$$$$$$$$") with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model # Return back to training mode FastLanguageModel.for_training (model) + yield model pass import trl.trainer From 5ccb46ab9126e531a6b56b383382331fb8a2eb12 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:32:51 -0800 Subject: [PATCH 155/473] Update rl.py --- unsloth/models/rl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 932a29f78..2282e8b31 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -43,8 +43,7 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model # Return back to training mode - FastLanguageModel.for_training (model) - yield model + FastLanguageModel.for_training(model) pass import trl.trainer From eeca1a611b5a92d9425362b060d181511731f0be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:37:32 -0800 Subject: [PATCH 156/473] Update rl.py --- unsloth/models/rl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2282e8b31..b51be3b7f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,11 +39,14 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, 
*args, **kwargs) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) + FastLanguageModel.for_inference(model) + try: + unwrapped_model = unwrap_model_for_generation(model, *args, **kwargs) yield unwrapped_model - # Return back to training mode - FastLanguageModel.for_training(model) + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass import trl.trainer From 4d1e272a0e8bbb6b4d8fe3c7840a029ea4b71225 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:41:37 -0800 Subject: [PATCH 157/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b51be3b7f..26c73a7b1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -41,8 +41,8 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth FastLanguageModel.for_inference(model) try: - unwrapped_model = unwrap_model_for_generation(model, *args, **kwargs) - yield unwrapped_model + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model finally: # Finally return back training FastLanguageModel.for_training(model) From 906055d4039b07bfb13110a715407dd9522fd5b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:44:18 -0800 Subject: [PATCH 158/473] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 26c73a7b1..6603346fd 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth + print("$$$$$$$$$$$$$$") FastLanguageModel.for_inference(model) try: with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: @@ -46,6 +47,7 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): finally: # Finally return back training FastLanguageModel.for_training(model) + print("###############") pass pass From e8ca0e7ee2de00d7a53f51239a095395e9502142 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:48:54 -0800 Subject: [PATCH 159/473] Update rl.py --- unsloth/models/rl.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 6603346fd..d77a4b378 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,15 +39,14 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - print("$$$$$$$$$$$$$$") - FastLanguageModel.for_inference(model) - try: - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + try: yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) - print("###############") + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass pass From 9a2999bad9a33f3f4dd6e9f9829c0a276875592e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:50:14 -0800 Subject: [PATCH 160/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py 
b/unsloth/models/rl.py index d77a4b378..3129488f3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,7 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) try: yield unwrapped_model From 6d92ed61dba92224e6b0a2bfa50dee7a124c4dfd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:50:39 -0800 Subject: [PATCH 161/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3129488f3..72b911acb 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -37,9 +37,9 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) try: yield unwrapped_model From 2c6f31ffe00ac074ca7a5f31c7768a806e15fdfb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:56:12 -0800 Subject: [PATCH 162/473] Update rl.py --- unsloth/models/rl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 72b911acb..72b568790 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,13 +39,15 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) - try: - yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) + with torch.inference_mode(): + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + try: + yield unwrapped_model + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass pass pass From 65f991e2cf6da5c768f7628d030577a160dc4915 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:58:04 -0800 Subject: [PATCH 163/473] Update rl.py --- unsloth/models/rl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 72b568790..2431e5a70 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,15 +39,13 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with torch.inference_mode(): - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) - try: - yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) - pass + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + 
try: + yield unwrapped_model.eval() + finally: + # Finally return back training + FastLanguageModel.for_training(model) pass pass pass From c08c009798066eba17c522039edc8f676bb373f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:13:40 -0800 Subject: [PATCH 164/473] Update rl.py --- unsloth/models/rl.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2431e5a70..06634ae3c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -33,16 +33,16 @@ def PatchRL(FastLanguageModel): - from trl.models.utils import unwrap_model_for_generation + from trl.models import unwrap_model_for_generation from contextlib import contextmanager @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth + FastLanguageModel.for_inference(model) with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) try: - yield unwrapped_model.eval() + yield unwrapped_model finally: # Finally return back training FastLanguageModel.for_training(model) @@ -50,6 +50,10 @@ def unsloth_unwrap_model_for_generation(model, accelerator): pass pass + import trl.models + trl.models.utils.unwrap_model_for_generation = unwrap_model_for_generation + trl.models.unwrap_model_for_generation = unwrap_model_for_generation + import trl.trainer trainers = dir(trl.trainer) trainers = [x for x in trainers if x.endswith("_trainer")] From a773af2635e2020542f91864ac069b79da8a042a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:25:05 -0800 Subject: [PATCH 165/473] Update rl.py --- unsloth/models/rl.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 06634ae3c..88db94bdf 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -33,27 +33,38 @@ def PatchRL(FastLanguageModel): - from trl.models import unwrap_model_for_generation + from trl.models.utils import unwrap_model_for_generation from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, accelerator): - # Must use for_inference to allow inference in Unsloth - FastLanguageModel.for_inference(model) + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + # Put the model in inference mode. + FastLanguageModel.for_inference(unwrapped_model) + + # Monkey-patch the generate method so it clones its output. + original_generate = unwrapped_model.generate + + def generate_with_clone(*args, **kwargs): + out = original_generate(*args, **kwargs) + # If the output is a tensor (i.e. an inference tensor), clone it. + if isinstance(out, torch.Tensor): + return out.clone() + # Optionally, if out is a tuple or dict containing tensors, you + # might want to iterate over it and clone all tensors. + return out + + # Replace the generate method. + unwrapped_model.generate = generate_with_clone + try: yield unwrapped_model finally: - # Finally return back training + # Restore the original generate method and reset the model mode. 
+ unwrapped_model.generate = original_generate FastLanguageModel.for_training(model) - pass - pass pass - import trl.models - trl.models.utils.unwrap_model_for_generation = unwrap_model_for_generation - trl.models.unwrap_model_for_generation = unwrap_model_for_generation - import trl.trainer trainers = dir(trl.trainer) trainers = [x for x in trainers if x.endswith("_trainer")] From fb24fc06737eb61ef8b833d509fcef2084d0fc2a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:27:07 -0800 Subject: [PATCH 166/473] Update rl.py --- unsloth/models/rl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 88db94bdf..21ade011e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -41,28 +41,26 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(unwrapped_model) - - # Monkey-patch the generate method so it clones its output. - original_generate = unwrapped_model.generate + # We must use .clone for Unsloth since we force inference_mode + # Rather we should have used no_grad + original_generate = unwrapped_model.generate def generate_with_clone(*args, **kwargs): out = original_generate(*args, **kwargs) - # If the output is a tensor (i.e. an inference tensor), clone it. if isinstance(out, torch.Tensor): return out.clone() - # Optionally, if out is a tuple or dict containing tensors, you - # might want to iterate over it and clone all tensors. return out - - # Replace the generate method. + pass unwrapped_model.generate = generate_with_clone try: yield unwrapped_model finally: - # Restore the original generate method and reset the model mode. + # Restore generate and return unwrapped_model.generate = original_generate FastLanguageModel.for_training(model) + pass + pass pass import trl.trainer From 30b0fa80b91274d1d1868bebf36dd7e3d26a5ec1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:28:16 -0800 Subject: [PATCH 167/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 21ade011e..0253fca7a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -37,7 +37,7 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + def unsloth_unwrap_model_for_generation(model, accelerator): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: # Put the model in inference mode. 
FastLanguageModel.for_inference(unwrapped_model) From 5bb5bfbb1162ba13465399b36f7275ddf1ece848 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 14:59:01 -0800 Subject: [PATCH 168/473] RL metrics --- unsloth/models/dpo.py | 113 ++---------------------------------------- unsloth/models/rl.py | 67 +++++++++++++++++-------- 2 files changed, 48 insertions(+), 132 deletions(-) diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index 5dc71f920..51f1c9a63 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -17,115 +17,8 @@ "PatchKTOTrainer", ] -try: - from transformers.utils.notebook import ( - IntervalStrategy, - NotebookTrainingTracker, - NotebookProgressCallback, - ) - HAS_NOTEBOOK = True -except: - HAS_NOTEBOOK = False -pass -import torch -from ._utils import torch_compile_options -import inspect -import torch.nn as nn -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from .rl import PatchRLStatistics +def PatchDPOTrainer(): PatchRLStatistics("DPO") -DPOTrainer_metrics = [ - "rewards/chosen", - "rewards/rejected", - "rewards/accuracies", - "rewards/margins", - "logps/rejected", - "logps/chosen", - "logits/rejected", - "logits/chosen", -] -set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics) - - -def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" - self.training_loss = 0 - self.last_log = 0 - column_names = [self.first_column] + ["Training Loss"] - if args.eval_strategy != IntervalStrategy.NO: - column_names.append("Validation Loss") - column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics] - self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) -pass - - -def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): - # Only for when there is no evaluation - if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: - values = {"Training Loss": logs["loss"]} - for metric in DPOTrainer_metrics: - values[metric.replace("/", " / ")] = logs[metric] - pass - # First column is necessarily Step since we're not in epoch eval strategy - values["Step"] = state.global_step - self.training_tracker.write_line(values) - pass -pass - - -def NotebookTrainingTracker_write_line(self, values): - """ - Write the values in the inner table. - - Args: - values (`Dict[str, float]`): The values to display. 
- """ - if self.inner_table is None: - self.inner_table = [list(values.keys()), list(values.values())] - else: - columns = self.inner_table[0] - new_values = {} - for key, value in values.items(): - lowered = key.lower() - if lowered in set_DPOTrainer_metrics: - new_values[lowered.replace("/", " / ")] = value - else: - new_values[key] = value - pass - values = new_values - - self.inner_table[0] = columns - if len(self.inner_table) > 1: - last_values = self.inner_table[-1] - first_column = self.inner_table[0][0] - if last_values[0] != values[first_column]: - # write new line - self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) - else: - # update last line - new_values = values - for c in columns: - if c not in new_values.keys(): - new_values[c] = last_values[columns.index(c)] - self.inner_table[-1] = [new_values[c] for c in columns] - else: - # Edit for evaluation purposes - self.inner_table.append([values[c] if c in values else 0 for c in columns]) - pass - pass -pass - - -def PatchDPOTrainer(): - if HAS_NOTEBOOK: - from transformers.trainer import is_in_notebook - if is_in_notebook(): - # Patch DPO notebook printing - NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line - from transformers.trainer import DEFAULT_PROGRESS_CALLBACK - DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin - DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log - pass - pass -pass -PatchKTOTrainer = PatchDPOTrainer +def PatchKTOTrainer(): PatchRLStatistics("KTO") diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0253fca7a..18b2415f2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -29,7 +29,10 @@ HAS_NOTEBOOK = False pass from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union - +import inspect +import os +import re +import functools def PatchRL(FastLanguageModel): @@ -94,7 +97,7 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw # Only for when there is no evaluation if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: values = {"Training Loss": logs["loss"]} - for metric in DPOTrainer_metrics: + for metric in Trainer_metrics: values[metric.replace("/", " / ")] = logs[metric] pass # First column is necessarily Step since we're not in epoch eval strategy @@ -167,27 +170,47 @@ def _PatchRLStatistics(metrics): pass +@functools.cache +def get_trl_metrics(): + # Gets metrics so we can output them in notebooks + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + filepath = inspect.getfile(trl.trainer) + filepath = os.path.split(filepath)[0] + + all_metrics = dict() + for trainer in trainers: + filename = os.path.join(filepath, f"{trainer}.py") + if not os.path.exists(filename): continue + with open(filename, "r") as file: file = file.read() + + # Get metrics['kl'] or stats['kl'] + metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) + stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) + metrics = metrics + stats + + # Get optional f-strings + metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + metrics_f = metrics_f + stats_f + # Filter out prefixes if seen + # metrics[f"{prefix}rewards/chosen"] + left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file + if left_prefix: metrics += metrics_f + + 
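# --- Illustrative sketch (not part of the patch series) ----------------------
# What the metric-scraping regexes above pull out of a TRL trainer source file.
# The snippet is made up purely for demonstration.
import re

snippet = '''
metrics["rewards/chosen"].append(chosen)
stats["kl"] = kl
metrics[f"{prefix}rewards/margins"] = margins
'''
plain   = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", snippet)
stats_  = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", snippet)
f_style = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", snippet)
print(plain, stats_, f_style)  # ['rewards/chosen'] ['kl'] ['rewards/margins']
# --- End of sketch ------------------------------------------------------------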
all_metrics[trainer[:trainer.find("_")].upper()] = metrics + pass + return all_metrics +pass + + def PatchRLStatistics(algorithm = "GRPO"): algorithm = algorithm.upper() - if algorithm == "GRPO": - metrics = [ - "completion_length", - "reward", - "reward_std", - "kl", - ] - elif algorithm == "DPO" or algorithm == "KTO": - metrics = [ - "rewards/chosen", - "rewards/rejected", - "rewards/accuracies", - "rewards/margins", - "logps/rejected", - "logps/chosen", - "logits/rejected", - "logits/chosen", - ] - else: + all_metrics = get_trl_metrics() + if algorithm not in all_metrics: print(f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.") - _PatchRLStatistics(metrics) + pass + _PatchRLStatistics(all_metrics[algorithm]) pass From 0b6db78d6a9650ec1acc25ad6f6e761f73bbbb04 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:02:52 -0800 Subject: [PATCH 169/473] Update rl.py --- unsloth/models/rl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 18b2415f2..02bc10c6f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -156,8 +156,10 @@ def _NotebookTrainingTracker_write_line(self, values): pass -def _PatchRLStatistics(metrics): +def _PatchRLStatistics(metrics, algorithm): if HAS_NOTEBOOK: + if len(metrics) == 0: + raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") from transformers.trainer import is_in_notebook if is_in_notebook(): # Patch DPO notebook printing @@ -210,7 +212,10 @@ def PatchRLStatistics(algorithm = "GRPO"): algorithm = algorithm.upper() all_metrics = get_trl_metrics() if algorithm not in all_metrics: - print(f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.") + print( + f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ + f"We support: `{list(all_metrics.keys())}`" + ) pass - _PatchRLStatistics(all_metrics[algorithm]) + _PatchRLStatistics(all_metrics[algorithm], algorithm) pass From 115701a74ad6ced46a51e6f072fecc6faa82dd96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:08:10 -0800 Subject: [PATCH 170/473] RL metrics --- unsloth/models/dpo.py | 6 +++--- unsloth/models/rl.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index 51f1c9a63..9c12abb98 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -17,8 +17,8 @@ "PatchKTOTrainer", ] -from .rl import PatchRLStatistics +from .rl import PatchFastRL -def PatchDPOTrainer(): PatchRLStatistics("DPO") +def PatchDPOTrainer(): PatchFastRL("DPO") -def PatchKTOTrainer(): PatchRLStatistics("KTO") +def PatchKTOTrainer(): PatchFastRL("KTO") diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 02bc10c6f..40d68f6a7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -13,8 +13,7 @@ # limitations under the License. 
__all__ = [ - "PatchRL", - "PatchRLStatistics", + "PatchFastRL", ] import torch @@ -202,6 +201,9 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f + # Remove all eval_ things + metrics = [x for x in metrics if not x.startswith("eval_")] + all_metrics[trainer[:trainer.find("_")].upper()] = metrics pass return all_metrics @@ -219,3 +221,9 @@ def PatchRLStatistics(algorithm = "GRPO"): pass _PatchRLStatistics(all_metrics[algorithm], algorithm) pass + + +def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): + if FastLanguageModel is not None: PatchRL(FastLanguageModel) + PatchRLStatistics(algorithm) +pass From 12038fd534fc0b2759e4f7efc14b2cff2bc65c27 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:11:40 -0800 Subject: [PATCH 171/473] Update __init__.py --- unsloth/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 279080173..b15e04ab7 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchRL, PatchRLStatistics +from .rl import PatchFastRL From e2a526e9d069b13f0a138e8af2d7d48a530e5ec7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:16:44 -0800 Subject: [PATCH 172/473] Update rl.py --- unsloth/models/rl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 40d68f6a7..4c6d73ee8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -201,6 +201,21 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f + # Remove optional items + # if ...: metrics[...] = + metrics_optional = re.findall( + r"if[^\n]{1,}\n[\s]{4,}"\ + r"(?:metrics|stats)"\ + r"\["\ + r"(?:f[\"\']\{[^\}]{1,}\})?"\ + r"([^\"\']{1,})[\"\']"\ + r"\]", + file, + flags = re.MULTILINE, + ) + metrics_optional = set(metrics_optional) + metrics = [x for x in metrics if x not in metrics_optional] + # Remove all eval_ things metrics = [x for x in metrics if not x.startswith("eval_")] From e74dbb5bb45137a5d0a74cbe6057833217c7e75f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:21:53 -0800 Subject: [PATCH 173/473] Update rl.py --- unsloth/models/rl.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4c6d73ee8..752a9d9b2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -97,7 +97,9 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: values = {"Training Loss": logs["loss"]} for metric in Trainer_metrics: - values[metric.replace("/", " / ")] = logs[metric] + # Sometimes metric is not inside logs + try: values[metric.replace("/", " / ")] = logs[metric] + except: pass pass # First column is necessarily Step since we're not in epoch eval strategy values["Step"] = state.global_step @@ -201,21 +203,6 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f - # Remove optional items - # if ...: metrics[...] 
= - metrics_optional = re.findall( - r"if[^\n]{1,}\n[\s]{4,}"\ - r"(?:metrics|stats)"\ - r"\["\ - r"(?:f[\"\']\{[^\}]{1,}\})?"\ - r"([^\"\']{1,})[\"\']"\ - r"\]", - file, - flags = re.MULTILINE, - ) - metrics_optional = set(metrics_optional) - metrics = [x for x in metrics if x not in metrics_optional] - # Remove all eval_ things metrics = [x for x in metrics if not x.startswith("eval_")] From 054ebb3594a4dcfc1a7a967df65d94955545fad8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:36:59 -0800 Subject: [PATCH 174/473] Update rl.py --- unsloth/models/rl.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 752a9d9b2..ca1a1b5db 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,6 +16,12 @@ "PatchFastRL", ] +METRICS_MOVE_TO_END = [ + "nll", + "aux", + "beta", + "alpha", +] import torch try: from transformers.utils.notebook import ( @@ -203,8 +209,29 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f - # Remove all eval_ things - metrics = [x for x in metrics if not x.startswith("eval_")] + # Move all eval_ things to the end and reward to the front + beginning = [] + middle = [] + end = [] + for x in metrics: + lowered = x.lower() + if "reward" in lowered: + beginning.append(x) + elif x.lower().startswith("eval"): + end.append(x) + else: + # Check if we want to move to the end + moved = False + for move_end in METRICS_MOVE_TO_END: + if move_end in lowered: + end.append(x) + moved = True + break + if not moved: + middle.append(x) + pass + pass + metrics = beginning + middle + end all_metrics[trainer[:trainer.find("_")].upper()] = metrics pass From 4d68b9c17a0cedd4749fb86a0652c234801be111 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 17:52:36 -0800 Subject: [PATCH 175/473] Update chat_templates.py --- unsloth/chat_templates.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index d8dc38522..c40139323 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -759,6 +759,10 @@ CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates + +for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"): + CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"] + DEFAULT_SYSTEM_MESSAGE[version] = "" pass From 547867d44b3f1231839b27d399ba047fa38964ec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 18:31:01 -0800 Subject: [PATCH 176/473] Update mapper.py --- unsloth/models/mapper.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 6e6e402a0..c81290b66 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -304,25 +304,30 @@ "unsloth/Mistral-Small-Instruct-2409", "mistralai/Mistral-Small-Instruct-2409", ), - "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", + "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", + "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit", ), - 
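# --- Illustrative sketch (not part of the patch series) ----------------------
# How the column reordering added in PATCH 174 behaves: reward metrics go
# first, eval_* and anything in METRICS_MOVE_TO_END go last, the rest stay in
# the middle.
METRICS_MOVE_TO_END = ["nll", "aux", "beta", "alpha"]
metrics = ["kl", "rewards/chosen", "aux_loss", "eval_loss", "completion_length"]

beginning, middle, end = [], [], []
for x in metrics:
    lowered = x.lower()
    if "reward" in lowered:
        beginning.append(x)
    elif lowered.startswith("eval"):
        end.append(x)
    elif any(move_end in lowered for move_end in METRICS_MOVE_TO_END):
        end.append(x)
    else:
        middle.append(x)

print(beginning + middle + end)
# ['rewards/chosen', 'kl', 'completion_length', 'aux_loss', 'eval_loss']
# --- End of sketch ------------------------------------------------------------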
"unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-3B-Instruct", + "unsloth/Qwen2.5-3B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-7B-Instruct", + "unsloth/Qwen2.5-7B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-14B-Instruct", + "unsloth/Qwen2.5-14B-Instruct-bnb-4bit", ), "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-32B-Instruct", @@ -332,25 +337,30 @@ "unsloth/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-72B-Instruct", ), - "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B", "Qwen/Qwen2.5-0.5B", + "unsloth/Qwen2.5-0.5B-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B", "Qwen/Qwen2.5-1.5B", + "unsloth/Qwen2.5-1.5B-bnb-4bit", ), - "unsloth/Qwen2.5-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B", "Qwen/Qwen2.5-3B", + "unsloth/Qwen2.5-3B-bnb-4bit", ), - "unsloth/Qwen2.5-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B", "Qwen/Qwen2.5-7B", + "unsloth/Qwen2.5-7B-bnb-4bit", ), - "unsloth/Qwen2.5-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B", "Qwen/Qwen2.5-14B", + "unsloth/Qwen2.5-14B-bnb-4bit", ), "unsloth/Qwen2.5-32B-bnb-4bit" : ( "unsloth/Qwen2.5-32B", From 8be4bfa446ab80caafeb1f1870dce8e0abfad29e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 20:00:02 -0800 Subject: [PATCH 177/473] Fp8 cache --- unsloth/models/llama.py | 2 +- unsloth/models/loader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ab90d2cbb..a337472a3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1637,7 +1637,7 @@ def from_pretrained( fast_inference = False, # uses vLLM gpu_memory_utilization = 0.5, - float8_kv_cache = True, + float8_kv_cache = False, random_state = 3407, max_lora_rank = 16, disable_log_stats = False, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 144863b8d..ad312e004 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -76,7 +76,7 @@ def from_pretrained( fast_inference = False, # uses vLLM gpu_memory_utilization = 0.5, - float8_kv_cache = True, + float8_kv_cache = False, random_state = 3407, max_lora_rank = 16, disable_log_stats = False, From 9eb8bf10085baa0393eb100ffb50ce7b51b183d2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 20:24:33 -0800 Subject: [PATCH 178/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a337472a3..795281200 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,6 +384,7 @@ def LlamaAttention_fast_forward( assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) + print("#######", Q, self.q_proj.lora_B.default.weight) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From d2b66ca54da1e48fd759c520b3a98d71c722225d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 
Feb 2025 20:30:36 -0800 Subject: [PATCH 179/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 795281200..a337472a3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,7 +384,6 @@ def LlamaAttention_fast_forward( assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) - print("#######", Q, self.q_proj.lora_B.default.weight) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 604329ca61d616f0ed8386d6e617a273ff45f70d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 00:59:13 -0800 Subject: [PATCH 180/473] Update rl.py --- unsloth/models/rl.py | 132 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ca1a1b5db..b653fb960 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ import re import functools + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -240,6 +241,7 @@ def get_trl_metrics(): def PatchRLStatistics(algorithm = "GRPO"): + # Get notebook statistics columns to show up algorithm = algorithm.upper() all_metrics = get_trl_metrics() if algorithm not in all_metrics: @@ -252,7 +254,137 @@ def PatchRLStatistics(algorithm = "GRPO"): pass +def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): + # Patch for vLLM and Unsloth PEFT + import trl.trainer + + trainer = eval(f"trl.trainer.{trainer_file}") + name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] + assert(len(name) == 1) + RLTrainer_name = name[0] + RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + + try: + __init__ = inspect.getsource(RLTrainer.__init__) + except: + # Already patched most likely! 
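# --- Illustrative sketch (not part of the patch series) ----------------------
# The core technique _patch_trl_rl_trainers relies on, in isolation: read a
# class's source as text, rewrite the parts we control, then exec the result
# to obtain a patched class. All names below are invented for the demo.
import re

source = '''
class DemoTrainer:
    def __init__(self, use_vllm = False):
        if use_vllm:
            self.llm = "new vLLM engine"  # pretend this is an expensive construction
        else:
            self.llm = None
'''
# Mirror the vllm_part rewrite: reuse an existing engine instead of creating one.
source = re.sub(r'self\.llm = "new vLLM engine".*', 'self.llm = "model.vllm_engine"', source)

namespace = {}
exec(source, namespace)
print(namespace["DemoTrainer"](use_vllm = True).llm)  # model.vllm_engine
# --- End of sketch ------------------------------------------------------------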
+ return + all_imports = dir(trainer) + imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + + spaces = __init__.find("def") + __init__ = __init__.split("\n") + __init__ = "\n".join(x[spaces:] for x in __init__) + + vllm_part = re.findall( + r"(\n[\s]{4}"\ + r"if (self|args)\.use_vllm\:.+?"\ + r"\n[\s]{4,}"\ + "else:\n)", + __init__, + flags = re.MULTILINE | re.DOTALL, + ) + if (len(vllm_part) != 1): return + + vllm_part, args = vllm_part[0][0], vllm_part[0][1] + # Strip all comments + new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + + # Get SamplingParams + sampling_params = re.findall( + r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ + r"SamplingParams\(.+?\))", + new_vllm_part, + flags = re.MULTILINE | re.DOTALL, + ) + if len(sampling_params) != 1: return + + sampling_params = sampling_params[0] + sampling_params = \ + " "*8 + "self.llm = model.vllm_engine; " + \ + sampling_params # Add spaces + new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" + __init__ = __init__.replace(vllm_part, new_vllm_part) + + # Remove peft_config + __init__ = __init__.replace("elif peft_config is None:", "elif False:") + __init__ = __init__.replace("elif peft_config is not None:", "elif False:") + __init__ = __init__.replace("if peft_config is None:", "if False:") + __init__ = __init__.replace("if peft_config is not None:", "if False:") + __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") + + # Search for vLLM calling in all child functions + functions = dir(RLTrainer) + RLTrainer_source = inspect.getsource(RLTrainer) + functions = [x for x in functions if f"def {x}" in RLTrainer_source] + + changed = {"__init__" : __init__} + for function in functions: + if not hasattr(RLTrainer, function): continue + fx = getattr(RLTrainer, function) + try: + source = inspect.getsource(fx) + except: + continue + original_source = source + + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + source = re.sub( + r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n", + r"\n\1pass\n", + source, + ) + # llm_model.load_weights(model.state_dict().items()) + source = re.sub( + r"(\n[\s]{4,}).+?load_weights\(.+?\n", + r"\n\1pass\n", + source, + ) + # Replace self.llm.generate and self.llm.chat + lora_name = trainer_file + "_lora_model" + source = re.sub( + r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)", + r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", + source + ) + if source == original_source: continue + + # Find all imports + imports += [x for x in all_imports if not x.startswith("_") and x in source] + + # Create actual function + spaces = source.find("def") + source = source.split("\n") + source = "\n".join(x[spaces:] for x in source) + changed[function] = source + pass + + # Import all functions + imports = list(set(imports)) + imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" + exec(imports) + + # Patch all functions + for function in changed: + exec(changed[function]) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = {function}") + pass +pass + + +def patch_trl_rl_trainers(): + # Patch all TRL modules if they have vLLM or PEFT + import trl.trainer + all_trainers = dir(trl.trainer) + all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")] + for trainer in all_trainers: + _patch_trl_rl_trainers(trainer) + return +pass + + def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel 
is not None: PatchRL(FastLanguageModel) + patch_trl_rl_trainers() PatchRLStatistics(algorithm) pass From 2c158dfbce48e11656c5a485529d007d13bfc3a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:02:31 -0800 Subject: [PATCH 181/473] Update rl.py --- unsloth/models/rl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b653fb960..13c2a62f1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -276,6 +276,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) + # Replace vLLM sections since we already have it done! vllm_part = re.findall( r"(\n[\s]{4}"\ r"if (self|args)\.use_vllm\:.+?"\ @@ -300,6 +301,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if len(sampling_params) != 1: return sampling_params = sampling_params[0] + # Replace with our vLLM engine sampling_params = \ " "*8 + "self.llm = model.vllm_engine; " + \ sampling_params # Add spaces @@ -334,12 +336,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): r"\n\1pass\n", source, ) + # llm_model.load_weights(model.state_dict().items()) source = re.sub( r"(\n[\s]{4,}).+?load_weights\(.+?\n", r"\n\1pass\n", source, ) + # Replace self.llm.generate and self.llm.chat lora_name = trainer_file + "_lora_model" source = re.sub( @@ -347,6 +351,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", source ) + + # Skip if no changes done if source == original_source: continue # Find all imports From 43116a21ee81ff3f76dba86295e428273369359d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:05:48 -0800 Subject: [PATCH 182/473] Update rl.py --- unsloth/models/rl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 13c2a62f1..ef2fcb567 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -362,6 +362,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): spaces = source.find("def") source = source.split("\n") source = "\n".join(x[spaces:] for x in source) + + # Replace function name with _unsloth_... 
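# --- Hedged usage sketch (not part of the patch series) ----------------------
# How PatchFastRL defined above is meant to be called from user code, assuming
# a GRPO workflow; the model name and hyperparameters are placeholders.
from unsloth import FastLanguageModel, PatchFastRL

PatchFastRL("GRPO", FastLanguageModel)      # patch TRL before building the trainer

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct",
    max_seq_length = 1024,
    load_in_4bit = True,
    fast_inference = True,                  # enables the vLLM engine reused above
)
# --- End of sketch ------------------------------------------------------------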
+ source = source.replace("def ", "def _unsloth_", 1) changed[function] = source pass @@ -372,8 +375,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: - exec(changed[function]) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = {function}") + exec(changed[function], locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 656ce86bd94ed4611fbbc2449cefa9cd8661d660 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:06:20 -0800 Subject: [PATCH 183/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ef2fcb567..0a516bce2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -371,7 +371,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports) + exec(imports, locals(), globals()) # Patch all functions for function in changed: From 832cd9b34b0c7cf0979e6fa9e6de22c2229afc47 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:06:40 -0800 Subject: [PATCH 184/473] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0a516bce2..5b9aec652 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -371,12 +371,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports, locals(), globals()) + imported_functions = {} + exec(imports, imported_functions) # Patch all functions for function in changed: - exec(changed[function], locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) + exec(changed[function], imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) pass pass From 8178b32271b5b053d3a368a3cac5aed525589ed2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:07:00 -0800 Subject: [PATCH 185/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5b9aec652..cd0bb0b39 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -377,6 +377,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: exec(changed[function], imported_functions, globals()) + print(changed[function]) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) pass pass From 40bb9456d88c7e59801d83221cb401ec3b021001 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:08:31 -0800 Subject: [PATCH 186/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cd0bb0b39..bc6fa0f7a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -275,6 +275,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): spaces = __init__.find("def") __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) + 
__init__ = __init__.replace("def ", "def _unsloth_", 1) # Replace vLLM sections since we already have it done! vllm_part = re.findall( From 9d71ee4c4e701617858190394fd8347766c0ac54 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:10:42 -0800 Subject: [PATCH 187/473] Update rl.py --- unsloth/models/rl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index bc6fa0f7a..1f33d46e6 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -372,14 +372,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - imported_functions = {} - exec(imports, imported_functions) + exec(imports, locals()) # Patch all functions for function in changed: - exec(changed[function], imported_functions, globals()) + exec(changed[function], locals(), globals()) print(changed[function]) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 1ee8492b97fdab719fcb597399fdc947a3d6153a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:12:33 -0800 Subject: [PATCH 188/473] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1f33d46e6..071818a7a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -377,7 +377,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: exec(changed[function], locals(), globals()) - print(changed[function]) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 58cd0c9c6d2a50802f0d8d5cf51e8f9fa2c6d4e5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:16:43 -0800 Subject: [PATCH 189/473] Update rl.py --- unsloth/models/rl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 071818a7a..b48f9eeee 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -271,6 +271,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): return all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + imports += ["Trainer"] spaces = __init__.find("def") __init__ = __init__.split("\n") @@ -316,6 +317,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("if peft_config is not None:", "if False:") __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") + # Change super() to Trainer + __init__ = __init__.replace("super()", "super(Trainer, self)") + # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) From fd347a2c416347628424b9669c6e2d1d80ef5166 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:47:44 -0800 Subject: [PATCH 190/473] Update rl.py --- unsloth/models/rl.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b48f9eeee..c7d3ab2c2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -269,14 +269,15 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: # Already 
patched most likely! return + old__init__ = __init__ all_imports = dir(trainer) - imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + assert("Union" in all_imports) + imports = [x for x in all_imports if not x.startswith("_")] imports += ["Trainer"] spaces = __init__.find("def") __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) - __init__ = __init__.replace("def ", "def _unsloth_", 1) # Replace vLLM sections since we already have it done! vllm_part = re.findall( @@ -318,14 +319,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") # Change super() to Trainer - __init__ = __init__.replace("super()", "super(Trainer, self)") + __init__ = __init__.replace("super()", f"super(Unsloth{RLTrainer_name}, self)") + + # Add spaces back into __init__ + __init__ = __init__.split("\n") + __init__ = "\n".join(' '*spaces + x for x in __init__) # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) functions = [x for x in functions if f"def {x}" in RLTrainer_source] - changed = {"__init__" : __init__} + changed = {"__init__" : (old__init__, __init__,)} for function in functions: if not hasattr(RLTrainer, function): continue fx = getattr(RLTrainer, function) @@ -363,26 +368,26 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] - # Create actual function - spaces = source.find("def") - source = source.split("\n") - source = "\n".join(x[spaces:] for x in source) - - # Replace function name with _unsloth_... - source = source.replace("def ", "def _unsloth_", 1) - changed[function] = source + changed[function] = (original_source, source,) pass # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports, locals()) + imported_functions = {} + exec(imports, globals(), imported_functions) # Patch all functions for function in changed: - exec(changed[function], locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) + old, new = changed[function] + RLTrainer_source = RLTrainer_source.replace(old, new) pass + RLTrainer_source = RLTrainer_source.replace( + f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 + ) + exec(RLTrainer_source, imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) pass From 9d06a56ff8ddf671ab6480be5b966aa8185437cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:50:22 -0800 Subject: [PATCH 191/473] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c7d3ab2c2..e8d9c19a5 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -256,6 +256,7 @@ def PatchRLStatistics(algorithm = "GRPO"): def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT + import trl import trl.trainer trainer = eval(f"trl.trainer.{trainer_file}") @@ -388,6 +389,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): exec(RLTrainer_source, imported_functions, globals()) 
exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) pass From 00b6aa803fa9e6cad6c3ce00be238249a7b11507 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:57:37 -0800 Subject: [PATCH 192/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e8d9c19a5..234257473 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -387,6 +387,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) exec(RLTrainer_source, imported_functions, globals()) + globals()[f"Unsloth{RLTrainer_name}"] = eval(f"Unsloth{RLTrainer_name}") exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) From 2c2388eb44e32d79c95cb3f07138476152694c8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:14:59 -0800 Subject: [PATCH 193/473] Update rl.py --- unsloth/models/rl.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 234257473..d70a83f71 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -38,6 +38,7 @@ import os import re import functools +from unsloth_zoo.compiler import create_new_function def PatchRL(FastLanguageModel): @@ -319,9 +320,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("if peft_config is not None:", "if False:") __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") - # Change super() to Trainer - __init__ = __init__.replace("super()", f"super(Unsloth{RLTrainer_name}, self)") - # Add spaces back into __init__ __init__ = __init__.split("\n") __init__ = "\n".join(' '*spaces + x for x in __init__) @@ -374,9 +372,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) - imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - imported_functions = {} - exec(imports, globals(), imported_functions) # Patch all functions for function in changed: @@ -386,11 +381,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) - exec(RLTrainer_source, imported_functions, globals()) - globals()[f"Unsloth{RLTrainer_name}"] = eval(f"Unsloth{RLTrainer_name}") - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + + module = create_new_function( + RLTrainer_name, + RLTrainer_source, + f"trl.trainer.{trainer_file}", + imports, + ) + return module pass From 9e3e1bacd6695b2e4752e6f6db3282f6d8c76d94 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:21:08 -0800 Subject: [PATCH 194/473] Update rl.py --- unsloth/models/rl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 
d70a83f71..fe1587f56 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -388,6 +388,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) + + # Patch over modules + exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) return module pass From 505daf88b8a70de4e3148a38ed7b7695293c28ef Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:24:13 -0800 Subject: [PATCH 195/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fe1587f56..c78587030 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -308,7 +308,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): sampling_params = sampling_params[0] # Replace with our vLLM engine sampling_params = \ - " "*8 + "self.llm = model.vllm_engine; " + \ + " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" __init__ = __init__.replace(vllm_part, new_vllm_part) From 5d53641a577813aa0a1c0213d861b97090ab9440 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:31:01 -0800 Subject: [PATCH 196/473] Update rl.py --- unsloth/models/rl.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c78587030..22e1e0f6c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -353,6 +353,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): source, ) + # .state_dict() + source = re.sub( + r"\.state_dict\(\)", + r"", + source, + ) + # Replace self.llm.generate and self.llm.chat lora_name = trainer_file + "_lora_model" source = re.sub( @@ -382,6 +389,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) + # Create new class in compiled cache and import it module = create_new_function( RLTrainer_name, RLTrainer_source, From cfb1a008962390a925e8448bc7a93f47351847c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:32:34 -0800 Subject: [PATCH 197/473] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 1f82dd8b5..c89fd0f1f 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.1.4"): + if Version(unsloth_zoo_version) < Version("2025.2.1"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: From 1f5a41813b026237549da0c751698a8fdfc916aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:36:42 -0800 Subject: [PATCH 198/473] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ad312e004..39b367e27 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -78,8 +78,8 @@ def from_pretrained( gpu_memory_utilization = 0.5, float8_kv_cache = False, random_state = 3407, - 
max_lora_rank = 16, - disable_log_stats = False, + max_lora_rank = 64, + disable_log_stats = True, *args, **kwargs, ): if token is None: token = get_token() From 34d92aa6941b89380f2ef4128b1891cfe3793ac4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:23:20 -0800 Subject: [PATCH 199/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 22e1e0f6c..e5101662b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -364,7 +364,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): lora_name = trainer_file + "_lora_model" source = re.sub( r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)", - r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", + r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))", source ) From 8b7c3af8c3f9270b410c9f20121a2dfa45a1a4e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:29:04 -0800 Subject: [PATCH 200/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e5101662b..515c6587f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -47,8 +47,8 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, accelerator): - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(unwrapped_model) From 066ec25f187a4e39092bf980ae894a941258b4cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 05:07:37 -0800 Subject: [PATCH 201/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index be7d2214a..2ec4adaa1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.3" +__version__ = "2025.2.4" __all__ = [ "SUPPORTS_BFLOAT16", From 052b93f0d58f2ebfbc94a7f4d135809ba187554b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Feb 2025 19:19:51 -0800 Subject: [PATCH 202/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index f2b0da860..3b336664d 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1059,8 +1059,11 @@ def patch_sft_trainer_tokenizer(): if trainer_text is None: continue try: exec(trainer_text, globals()) - except: - raise RuntimeError(f"Unsloth: Please file a bug report! Error patching {trainer_name}") + except Exception as error: + raise RuntimeError( + f"Unsloth: Please file a bug report! Error patching {trainer_name}. 
Error:\n"\ + f"{str(error)}", + ) exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) pass From fdac0252ecf5173c043dd59bba3820ccbe199e7a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Feb 2025 19:21:58 -0800 Subject: [PATCH 203/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 3b336664d..cb8852a30 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1058,6 +1058,7 @@ def patch_sft_trainer_tokenizer(): trainer_text = patch_trl_tokenizer_processing_class(trainer_name) if trainer_text is None: continue try: + print(trainer_text) exec(trainer_text, globals()) except Exception as error: raise RuntimeError( From ade058e124890592c3f9fba86d785b7ebfdfdddf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:24:36 -0800 Subject: [PATCH 204/473] Better TRL handling --- unsloth/models/rl.py | 495 +++++++++++++++++++++++-------------------- 1 file changed, 264 insertions(+), 231 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 515c6587f..5d6117b70 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,29 +16,13 @@ "PatchFastRL", ] -METRICS_MOVE_TO_END = [ - "nll", - "aux", - "beta", - "alpha", -] import torch -try: - from transformers.utils.notebook import ( - IntervalStrategy, - NotebookTrainingTracker, - NotebookProgressCallback, - ) - HAS_NOTEBOOK = True -except: - HAS_NOTEBOOK = False -pass from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import inspect import os import re -import functools from unsloth_zoo.compiler import create_new_function +from unsloth_zoo.logging_utils import PatchRLStatistics def PatchRL(FastLanguageModel): @@ -78,219 +62,290 @@ def generate_with_clone(*args, **kwargs): trainers = [x for x in trainers if x.endswith("_trainer")] unwrap = "unwrap_model_for_generation" for trainer in trainers: - if hasattr(eval(f"trl.trainer.{trainer}"), unwrap): - exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + try: current_trainer = eval(f"trl.trainer.{trainer}") + except: continue + if hasattr(current_trainer, unwrap): + try: exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + except: continue pass pass -def NotebookProgressCallback_on_train_begin(Trainer_metrics): - def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" - self.training_loss = 0 - self.last_log = 0 - column_names = [self.first_column] + ["Training Loss"] - if args.eval_strategy != IntervalStrategy.NO: - column_names.append("Validation Loss") - column_names += [x.replace("/", " / ") for x in Trainer_metrics] - self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) - pass - return _NotebookProgressCallback_on_train_begin -pass +RLTrainer_replacement = ''' +from typing import * +from dataclasses import dataclass, field +@dataclass +class Unsloth{RLConfig_name}({RLConfig_name}): + """ + {__RLConfig_doc__} + """ + sampling_params: Optional[Any] = field( + default = None, + metadata = {{'help': 'vLLM SamplingParams'}}, + ) + def __init__({RLConfig_arguments}, + sampling_params = None + ): +{RLConfig_extra_args} + super().__init__({RLConfig_call_args}) +pass -def NotebookProgressCallback_on_log(Trainer_metrics): - def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): - # Only for when 
there is no evaluation - if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: - values = {"Training Loss": logs["loss"]} - for metric in Trainer_metrics: - # Sometimes metric is not inside logs - try: values[metric.replace("/", " / ")] = logs[metric] - except: pass - pass - # First column is necessarily Step since we're not in epoch eval strategy - values["Step"] = state.global_step - self.training_tracker.write_line(values) - pass - pass - return _NotebookProgressCallback_on_log +{RLTrainer_extras} + +class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): + """ + {__RLTrainer_doc__} + """ + def __init__({RLTrainer_arguments} + ): + if args is None: args = Unsloth{RLConfig_name}() +{RLTrainer_extra_args} + super().__init__({RLTrainer_call_args}) pass +''' +def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): + # Patch for vLLM and Unsloth PEFT + import trl + import trl.trainer + try: + trainer = eval(f"trl.trainer.{trainer_file}") + except Exception as error: + return + + # Get SFTTrainer and SFTConfig names + name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] + config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] + if len(name) != 1: return + if len(config) != 1: return + + # Get SFTTrainer, SFTConfig + RLTrainer_name = name[0] + RLConfig_name = config[0] + try: RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + except: return + try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) + except: return -def NotebookTrainingTracker_write_line(Trainer_metrics): - set_Trainer_metrics = set(Trainer_metrics) - def _NotebookTrainingTracker_write_line(self, values): - """ - Write the values in the inner table. - - Args: - values (`Dict[str, float]`): The values to display. 
- """ - if self.inner_table is None: - self.inner_table = [list(values.keys()), list(values.values())] - else: - columns = self.inner_table[0] - new_values = {} - for key, value in values.items(): - lowered = key.lower() - if lowered in set_Trainer_metrics: - new_values[lowered.replace("/", " / ")] = value - else: - new_values[key] = value - pass - values = new_values - - self.inner_table[0] = columns - if len(self.inner_table) > 1: - last_values = self.inner_table[-1] - first_column = self.inner_table[0][0] - if last_values[0] != values[first_column]: - # write new line - self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) - else: - # update last line - new_values = values - for c in columns: - if c not in new_values.keys(): - new_values[c] = last_values[columns.index(c)] - self.inner_table[-1] = [new_values[c] for c in columns] - else: - # Edit for evaluation purposes - self.inner_table.append([values[c] if c in values else 0 for c in columns]) - pass - pass - pass - return _NotebookTrainingTracker_write_line -pass + # Check name + if RLTrainer.__name__.startswith("Unsloth"): return + if RLConfig .__name__.startswith("Unsloth"): return + all_imports = dir(trainer) + imports = [x for x in all_imports if not x.startswith("_")] -def _PatchRLStatistics(metrics, algorithm): - if HAS_NOTEBOOK: - if len(metrics) == 0: - raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") - from transformers.trainer import is_in_notebook - if is_in_notebook(): - # Patch DPO notebook printing - NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) - from transformers.trainer import DEFAULT_PROGRESS_CALLBACK - DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) - DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + # Get default arguments + EMPTY = inspect.Parameter.empty + processed = [] + for RLobject in [RLTrainer, RLConfig]: + parameters = inspect.signature(RLobject.__init__).parameters + types = (bool, type(None), int, float, str,) + arguments = ["self"] + call_args = [] + for k, v in parameters.items(): + if k == "self": continue + v = v.default + if v == "\n": v = re.escape("\n") + if v is EMPTY: arguments.append(k) + elif type(v) is str: arguments.append(f"{k} = '{v}'") + elif type(v) in types: arguments.append(f"{k} = {v}") + else: continue + call_args.append(f"{k} = {k}") pass + arguments = f"\n{' '*8}" + f",\n{' '*8}".join(arguments) + call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) + processed.append((arguments, call_args,)) pass -pass + # Process RLTrainer first + arguments, call_args = processed[0] -@functools.cache -def get_trl_metrics(): - # Gets metrics so we can output them in notebooks + # Add tokenizer if not seen + if "tokenizer" not in parameters and "processing_class" in parameters: + arguments += f",\n{' '*8}tokenizer = None" + call_args = call_args.replace( + "processing_class = processing_class", + "processing_class = tokenizer if tokenizer is not None else processing_class", + ) + pass - import trl.trainer - trainers = dir(trl.trainer) - trainers = [x for x in trainers if x.endswith("_trainer")] - filepath = inspect.getfile(trl.trainer) - filepath = os.path.split(filepath)[0] + # Edit bf16, fp16 by checking model's torch_dtype directly + extra_args = "" + if "args" in call_args: + mixed_precision = \ + "use_bf16 = getattr(args, 'bf16', False)\n"\ + "use_fp16 = getattr(args, 'fp16', False)\n"\ + "dtype = 
getattr(model.config, 'torch_dtype', None)\n"\ + "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\ + "from unsloth_zoo.utils import _get_dtype\n"\ + "dtype = _get_dtype(dtype)\n"\ + "float16 = dtype == torch.float16\n"\ + "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ + "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ + "if not use_bf16 and not use_fp16:\n"\ + " args.fp16 = float16\n"\ + " args.bf16 = not float16\n" + extra_args += mixed_precision + pass - all_metrics = dict() - for trainer in trainers: - filename = os.path.join(filepath, f"{trainer}.py") - if not os.path.exists(filename): continue - with open(filename, "r") as file: file = file.read() - - # Get metrics['kl'] or stats['kl'] - metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) - stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) - metrics = metrics + stats - - # Get optional f-strings - metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) - stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) - metrics_f = metrics_f + stats_f - # Filter out prefixes if seen - # metrics[f"{prefix}rewards/chosen"] - left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file - if left_prefix: metrics += metrics_f - - # Move all eval_ things to the end and reward to the front - beginning = [] - middle = [] - end = [] - for x in metrics: - lowered = x.lower() - if "reward" in lowered: - beginning.append(x) - elif x.lower().startswith("eval"): - end.append(x) - else: - # Check if we want to move to the end - moved = False - for move_end in METRICS_MOVE_TO_END: - if move_end in lowered: - end.append(x) - moved = True - break - if not moved: - middle.append(x) - pass + # Check if per_device_eval_batch_size (default 8) bigger than bsz + # Also use FP16 / BF16 evaluation + if "args" in call_args: + # Check eval_dataset first + if "eval_dataset" in call_args: + check_eval_dataset = \ + "if getattr(args, 'eval_strategy', 'no') == 'no':\n"\ + " args.eval_strategy = 'steps'\n"\ + " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n" + extra_args += check_eval_dataset pass - metrics = beginning + middle + end - all_metrics[trainer[:trainer.find("_")].upper()] = metrics + eval_changes = \ + "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\ + "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\ + " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\ + " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\ + " if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"\ + "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"\ + "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\ + "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ + "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ + "if not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16; args.fp16_full_eval = args.fp16\n" + + extra_args += eval_changes pass - return all_metrics -pass + # Add statistics as well! 
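# --- Illustrative sketch (not part of the patch series) ----------------------
# The mixed-precision check injected above, restated as a plain function: the
# config's fp16/bf16 flags must agree with the dtype the model was loaded in,
# and are filled in automatically when neither flag was set.
import torch

def resolve_precision(model_dtype, use_bf16 = False, use_fp16 = False):
    float16 = model_dtype == torch.float16
    if float16 and use_bf16:
        raise TypeError("Model is float16 but bf16 training was requested")
    if not float16 and use_fp16:
        raise TypeError("Model is bfloat16 but fp16 training was requested")
    if not use_bf16 and not use_fp16:
        return {"fp16": float16, "bf16": not float16}
    return {"fp16": use_fp16, "bf16": use_bf16}

print(resolve_precision(torch.bfloat16))  # {'fp16': False, 'bf16': True}
# --- End of sketch ------------------------------------------------------------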
+ extra_args += \ + "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ + f"PatchRLStatistics('{trainer_file}')\n" + + # Create RLTrainer args + extra_args = extra_args.split("\n") + extra_args = "\n".join(" "*8 + x for x in extra_args) + RLTrainer_arguments = arguments + RLTrainer_extra_args = extra_args + RLTrainer_call_args = call_args + + # Fix RLConfig next + arguments, call_args = processed[1] + extra_args = "" + + # Edit GA / bsz and weight_decay + replacements = { + "output_dir" : 'unsloth_training_checkpoints', + "logging_nan_inf_filter" : False, + "per_device_train_batch_size" : 4, + "gradient_accumulation_steps" : 2, + "weight_decay" : 0.01, + "warmup_ratio" : 0.1, + "seed" : 3407, + "optim" : "adamw_8bit", + "learning_rate" : 5e-05, + "per_device_eval_batch_size" : 4, + "eval_accumulation_steps" : 2, + "torch_empty_cache_steps" : 250, + } + for k, v in replacements.items(): + x = f"{k}( = [^,\n]{{1,}})?,\n" + y = f"'{v}'" if type(v) is str else f"{v}" + y = f"{k} = {y},\n" + arguments = re.sub(x, y, arguments) + pass -def PatchRLStatistics(algorithm = "GRPO"): - # Get notebook statistics columns to show up - algorithm = algorithm.upper() - all_metrics = get_trl_metrics() - if algorithm not in all_metrics: - print( - f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ - f"We support: `{list(all_metrics.keys())}`" - ) + # Warn on too large or too small learning rate + if " learning_rate" in call_args: + learning_rate_check = \ + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! 
Consider decreasing it to 1e-1, otherwise gradient updates will explode!')" + extra_args += learning_rate_check pass - _PatchRLStatistics(all_metrics[algorithm], algorithm) -pass + # Create RLConfig args + extra_args = extra_args.split("\n") + extra_args = "\n".join(" "*8 + x for x in extra_args) + RLConfig_arguments = arguments + RLConfig_extra_args = extra_args + RLConfig_call_args = call_args + + # Patch vLLM + RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) + if RLTrainer_extras is None: + RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" + + # Create full module + exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") + __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ + __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + + RLTrainer_source = RLTrainer_replacement.format( + RLTrainer_name = RLTrainer_name, + __RLTrainer_doc__ = __RLTrainer_doc__, + RLTrainer_arguments = RLTrainer_arguments, + RLTrainer_extra_args = RLTrainer_extra_args, + RLTrainer_call_args = RLTrainer_call_args, + + RLConfig_name = RLConfig_name, + __RLConfig_doc__ = __RLConfig_doc__, + RLConfig_arguments = RLConfig_arguments, + RLConfig_extra_args = RLConfig_extra_args, + RLConfig_call_args = RLConfig_call_args, + + RLTrainer_extras = RLTrainer_extras, + ) -def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): - # Patch for vLLM and Unsloth PEFT - import trl - import trl.trainer + # Create new function + created_module = create_new_function( + f"Unsloth{RLTrainer_name}", + RLTrainer_source, + f"trl.trainer.{trainer_file}", + imports, + ) + + # Patch Trainer + exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + + # Patch Config + exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) + exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) +pass - trainer = eval(f"trl.trainer.{trainer_file}") - name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] - assert(len(name) == 1) - RLTrainer_name = name[0] - RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") - try: - __init__ = inspect.getsource(RLTrainer.__init__) - except: - # Already patched most likely! 
- return - old__init__ = __init__ - all_imports = dir(trainer) - assert("Union" in all_imports) - imports = [x for x in all_imports if not x.startswith("_")] - imports += ["Trainer"] +def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): + RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + init = inspect.getsource(RLTrainer.__init__) + old_init = init - spaces = __init__.find("def") - __init__ = __init__.split("\n") - __init__ = "\n".join(x[spaces:] for x in __init__) + # Remove peft_config + init = init.replace("elif peft_config is None:", "elif False:") + init = init.replace("elif peft_config is not None:", "elif False:") + init = init.replace("if peft_config is None:", "if False:") + init = init.replace("if peft_config is not None:", "if False:") + init = init.replace("get_peft_model(model, peft_config)", "model") + + # Set use_vllm if not set + init = re.sub( + r"\)([ ]{0,}\-\>[ ]{0,}None[ ]{0,}):\n([\s]{4})", + r"):\n\2 "\ + r"if hasattr(model, 'vllm_engine') and "\ + r"getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + r"args.use_vllm = True\n\2", + init, 1, + ) - # Replace vLLM sections since we already have it done! vllm_part = re.findall( - r"(\n[\s]{4}"\ + r"(\n[\s]{8}"\ r"if (self|args)\.use_vllm\:.+?"\ - r"\n[\s]{4,}"\ + r"\n[\s]{8,}"\ "else:\n)", - __init__, + init, flags = re.MULTILINE | re.DOTALL, ) - if (len(vllm_part) != 1): return + if len(vllm_part) != 1: return None vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments @@ -303,40 +358,31 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - if len(sampling_params) != 1: return + if len(sampling_params) != 1: return None sampling_params = sampling_params[0] # Replace with our vLLM engine sampling_params = \ - " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces - new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" - __init__ = __init__.replace(vllm_part, new_vllm_part) - - # Remove peft_config - __init__ = __init__.replace("elif peft_config is None:", "elif False:") - __init__ = __init__.replace("elif peft_config is not None:", "elif False:") - __init__ = __init__.replace("if peft_config is None:", "if False:") - __init__ = __init__.replace("if peft_config is not None:", "if False:") - __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") - - # Add spaces back into __init__ - __init__ = __init__.split("\n") - __init__ = "\n".join(' '*spaces + x for x in __init__) + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) functions = [x for x in functions if f"def {x}" in RLTrainer_source] - changed = {"__init__" : (old__init__, __init__,)} + changed = {"__init__" : (old_init, init,)} + for function in functions: if not hasattr(RLTrainer, function): continue fx = getattr(RLTrainer, function) - try: - source = inspect.getsource(fx) - except: - continue + try: source = inspect.getsource(fx) + except: continue original_source = source # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model @@ -386,22 +432,9 
@@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace(old, new) pass RLTrainer_source = RLTrainer_source.replace( - f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 - ) - - # Create new class in compiled cache and import it - module = create_new_function( - RLTrainer_name, - RLTrainer_source, - f"trl.trainer.{trainer_file}", - imports, + f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) - - # Patch over modules - exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - return module + return RLTrainer_source pass From 15073c063f2eb91110de07e7309893edfa6f8824 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:25:37 -0800 Subject: [PATCH 205/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5d6117b70..e89e657fa 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -246,6 +246,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "per_device_eval_batch_size" : 4, "eval_accumulation_steps" : 2, "torch_empty_cache_steps" : 250, + "logging_steps" : 1, } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" From 0c54b1e0d2fa43d8154875de44e99a6c2b0c94d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:30:08 -0800 Subject: [PATCH 206/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 44 -------------------------------------- 1 file changed, 44 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index cb8852a30..cfaf6cebe 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -907,35 +907,6 @@ def neftune_post_forward_hook(module, input, output): pass -def patch_trl_tokenizer_processing_class(trainer_name): - # New TRL removes tokenizer! - # We return it back! 
- exec(f"from trl import {trainer_name}", globals()) - if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None - parameters = eval(f"inspect.signature({trainer_name}).parameters") - if "tokenizer" in parameters: return None - - args = { - key : \ - value.default \ - if type(value.default) is not str else \ - f"'{value.default}'" \ - for key, value in parameters.items() - } - args["tokenizer"] = None - new_args = args.copy() - del new_args["tokenizer"] - del new_args["processing_class"] - new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \ - f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class" - args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items()) - args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):" - args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})" - new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n""" - return new_class -pass - - def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes @@ -1053,20 +1024,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# Fix TRL trainers with removed tokenizer args (got replaced with processing_class) -for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): - trainer_text = patch_trl_tokenizer_processing_class(trainer_name) - if trainer_text is None: continue - try: - print(trainer_text) - exec(trainer_text, globals()) - except Exception as error: - raise RuntimeError( - f"Unsloth: Please file a bug report! Error patching {trainer_name}. Error:\n"\ - f"{str(error)}", - ) - exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) -pass - # FInally patch TRL tokenizer things patch_sft_trainer_tokenizer() From a820ac655c50e98efe8c67d4a49cc540200f09d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:33:15 -0800 Subject: [PATCH 207/473] Auto patching --- unsloth/models/llama.py | 2 ++ unsloth/models/rl.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a337472a3..c50f65e4b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2739,3 +2739,5 @@ def for_training(model, use_gradient_checkpointing = True): pass pass +from .rl import PatchFastRL +PatchFastRL(FastLanguageModel = FastLlamaModel) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e89e657fa..31a745e0d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -453,5 +453,5 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() - PatchRLStatistics(algorithm) + if algorithm is nont None: PatchRLStatistics(algorithm) pass From 15c52200979b958898f727d9ce7864092505d8c0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:06:08 -0800 Subject: [PATCH 208/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index cfaf6cebe..0b01ffff7 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -911,11 +911,14 @@ def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes """ + sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") for function_name, replacer in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_non_packed_dataloader", "def tokenize(element):", 
"_prepare_dataset",), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): - function = getsource(eval(f"trl.trainer.sft_trainer.SFTTrainer.{function_name}")) + if not hasattr(sft_trainer, function_name): continue + + function = getsource(eval(f"{sft_trainer}.{function_name}")) where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) @@ -924,14 +927,20 @@ def patch_sft_trainer_tokenizer(): "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ + "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\ "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\ - "add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - + "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\ + " from functools import partial\n"\ + " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ + " processing_class = tokenizer\n"\ + "else:\n"\ + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From 92a9f0b9604c9dd0ba368acf75a41942fa45eada Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:22:02 -0800 Subject: [PATCH 209/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 0b01ffff7..54f0e66c7 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -911,14 +911,19 @@ def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes """ - sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") + try: + sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") + except: + all_imports = dir(trl.trainer.sft_trainer) + for function_name, replacer in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):", "_prepare_dataset",), + ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): if not hasattr(sft_trainer, function_name): continue - function = getsource(eval(f"{sft_trainer}.{function_name}")) + function = getsource(eval(f"sft_trainer.{function_name}")) where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) @@ -940,14 +945,28 @@ def patch_sft_trainer_tokenizer(): " processing_class = tokenizer\n"\ "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) - function = function.replace(replacer, check_text + 
replacer) - exec(function, globals()) + if replacer is None: + replacer = re.findall( + f"def {function_name}\(.+?\).+?\:\n", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) == 0: continue + replacer = replacer[0] + function = function.replace(replacer, replacer + check_text) + else: + function = function.replace(replacer, check_text + replacer) + pass + x = [x for x in all_imports if x in function] + exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) + exec(function, locals(), globals()) exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals()) + print("Patched") pass # Patch train with fix_untrained_tokens From 61b185304a626affb0f1121450d6cd2cff0a0137 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:23:24 -0800 Subject: [PATCH 210/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 54f0e66c7..78494f8ef 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -914,6 +914,7 @@ def patch_sft_trainer_tokenizer(): try: sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") except: + return all_imports = dir(trl.trainer.sft_trainer) for function_name, replacer in ( From ea8739d3637847054a0b7cbe1d6f67ef223ca955 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:24:31 -0800 Subject: [PATCH 211/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 31a745e0d..bb99f6c88 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -453,5 +453,5 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() - if algorithm is nont None: PatchRLStatistics(algorithm) + if algorithm is not None: PatchRLStatistics(algorithm) pass From 61699bf7e7c39d363d90dca02b8fe6cff74dc862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:36:12 -0800 Subject: [PATCH 212/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 78494f8ef..5f904ad7d 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -917,7 +917,7 @@ def patch_sft_trainer_tokenizer(): return all_imports = dir(trl.trainer.sft_trainer) - for function_name, replacer in ( + for (function_name, replacer,) in ( ("_prepare_non_packed_dataloader", "def tokenize(element):",), ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), @@ -962,12 +962,12 @@ def patch_sft_trainer_tokenizer(): else: function = function.replace(replacer, check_text + replacer) pass + print(function) x = [x for x in all_imports if x in function] exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) exec(function, locals(), globals()) exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals()) - print("Patched") pass # Patch train with fix_untrained_tokens From acbf23fe110b76b883c46c6954ec631354855873 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:37:08 -0800 Subject: [PATCH 213/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 
bb99f6c88..50f979558 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -295,6 +295,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, ) + print(RLTrainer_source) # Create new function created_module = create_new_function( From b1b9af323e152dcebb63113e3582cd2256a0cfac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:42:47 -0800 Subject: [PATCH 214/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 5f904ad7d..c35d990d0 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -918,7 +918,8 @@ def patch_sft_trainer_tokenizer(): all_imports = dir(trl.trainer.sft_trainer) for (function_name, replacer,) in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):",), + # ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_non_packed_dataloader", None,), ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): From fee37b0c61b14946aea7e255f6d3ad2123892b21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:14:06 -0800 Subject: [PATCH 215/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index c35d990d0..e2ba5fab7 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -952,8 +952,9 @@ def patch_sft_trainer_tokenizer(): check_text = "\n".join(" "*where + x for x in check_text) if replacer is None: + # .*? matches first match. .+? matches final match. replacer = re.findall( - f"def {function_name}\(.+?\).+?\:\n", + f"def {function_name}\(.*?\).*?\:\n", function, flags = re.MULTILINE | re.DOTALL, ) From ff27094cddc6c090b15c0887b72a0dbc1c9377e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:17:13 -0800 Subject: [PATCH 216/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index e2ba5fab7..7e4baa60e 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -946,7 +946,7 @@ def patch_sft_trainer_tokenizer(): " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ "else:\n"\ - " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + " add_special_tokens = False if has_bos_token_already else add_special_tokens" check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From 6ab51bedae69f1e0ebd4455d71a4a7f48b2478c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:22:58 -0800 Subject: [PATCH 217/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 7e4baa60e..dcdd5c662 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -946,8 +946,9 @@ def patch_sft_trainer_tokenizer(): " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ "else:\n"\ - " add_special_tokens = False if has_bos_token_already else add_special_tokens" - + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + f"{' '*4}" + check_text = 
check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From b45f633d9547274c9300f2a80329029002d9120f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:25:17 -0800 Subject: [PATCH 218/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index dcdd5c662..4c5737788 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -947,8 +947,7 @@ def patch_sft_trainer_tokenizer(): " processing_class = tokenizer\n"\ "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - f"{' '*4}" - + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) @@ -961,6 +960,9 @@ def patch_sft_trainer_tokenizer(): ) if len(replacer) == 0: continue replacer = replacer[0] + print("====") + print(check_text) + print("====") function = function.replace(replacer, replacer + check_text) else: function = function.replace(replacer, check_text + replacer) From fd9e67774e43c702330ac0649ddd28e84c750d28 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:27:50 -0800 Subject: [PATCH 219/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 4c5737788..3d8a51738 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -961,7 +961,7 @@ def patch_sft_trainer_tokenizer(): if len(replacer) == 0: continue replacer = replacer[0] print("====") - print(check_text) + print(replacer) print("====") function = function.replace(replacer, replacer + check_text) else: From b9b3166dbdae79bed2cb23c5500cdbb0baa56d25 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:28:28 -0800 Subject: [PATCH 220/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 3d8a51738..2062df480 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -950,7 +950,8 @@ def patch_sft_trainer_tokenizer(): check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) - + check_text = check_text.rstrip() + "\n" + if replacer is None: # .*? matches first match. .+? matches final match. replacer = re.findall( From 7fdab17eae6124507191c672c8f105b18d4cf4d0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:30:02 -0800 Subject: [PATCH 221/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 2062df480..5226c3c5b 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -951,7 +951,7 @@ def patch_sft_trainer_tokenizer(): check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) check_text = check_text.rstrip() + "\n" - + if replacer is None: # .*? matches first match. .+? matches final match. 
replacer = re.findall( @@ -961,14 +961,10 @@ def patch_sft_trainer_tokenizer(): ) if len(replacer) == 0: continue replacer = replacer[0] - print("====") - print(replacer) - print("====") function = function.replace(replacer, replacer + check_text) else: function = function.replace(replacer, check_text + replacer) pass - print(function) x = [x for x in all_imports if x in function] exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) From 259597163f5a7056ce251460694fbe206f991010 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:33:43 -0800 Subject: [PATCH 222/473] Update rl.py --- unsloth/models/rl.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 50f979558..c4122f7aa 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -107,18 +107,21 @@ def __init__({RLTrainer_arguments} def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT + print(1) import trl import trl.trainer try: trainer = eval(f"trl.trainer.{trainer_file}") except Exception as error: return + print(2) # Get SFTTrainer and SFTConfig names name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] if len(name) != 1: return if len(config) != 1: return + print(3) # Get SFTTrainer, SFTConfig RLTrainer_name = name[0] @@ -127,6 +130,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: return try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) except: return + print(4) # Check name if RLTrainer.__name__.startswith("Unsloth"): return @@ -134,6 +138,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_")] + print(5) # Get default arguments EMPTY = inspect.Parameter.empty @@ -157,6 +162,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) processed.append((arguments, call_args,)) pass + print(6) # Process RLTrainer first arguments, call_args = processed[0] @@ -274,11 +280,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" + print(7) # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + print(8) RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -295,7 +303,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, ) - print(RLTrainer_source) # Create new function created_module = create_new_function( @@ -304,6 +311,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) + print(9) # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From f470f55e9b571977c9b2455bf04c3855ac62666c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:36:30 -0800 Subject: [PATCH 223/473] Update rl.py --- unsloth/models/rl.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 
deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c4122f7aa..112ba5d70 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -107,21 +107,18 @@ def __init__({RLTrainer_arguments} def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT - print(1) import trl import trl.trainer try: trainer = eval(f"trl.trainer.{trainer_file}") except Exception as error: return - print(2) # Get SFTTrainer and SFTConfig names name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] if len(name) != 1: return if len(config) != 1: return - print(3) # Get SFTTrainer, SFTConfig RLTrainer_name = name[0] @@ -130,7 +127,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: return try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) except: return - print(4) # Check name if RLTrainer.__name__.startswith("Unsloth"): return @@ -138,7 +134,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_")] - print(5) # Get default arguments EMPTY = inspect.Parameter.empty @@ -162,7 +157,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) processed.append((arguments, call_args,)) pass - print(6) # Process RLTrainer first arguments, call_args = processed[0] @@ -277,16 +271,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_call_args = call_args # Patch vLLM - RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) + RLTrainer_extras = patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" - print(7) # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ - print(8) RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -311,7 +303,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) - print(9) # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) @@ -326,6 +317,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): + import trl.trainer RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") init = inspect.getsource(RLTrainer.__init__) old_init = init From ddfdca112c03c884ea3549c9748efd200ed3bbb1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:36:57 -0800 Subject: [PATCH 224/473] Update rl.py --- unsloth/models/rl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 112ba5d70..81e929aac 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -316,9 +316,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass -def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): - import trl.trainer - RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") +def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, 
imports): init = inspect.getsource(RLTrainer.__init__) old_init = init From 3e0c7e2a329762c2115e6ec18f2d5abc20926161 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:39:51 -0800 Subject: [PATCH 225/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 81e929aac..3682c71d7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -449,7 +449,7 @@ def patch_trl_rl_trainers(): pass -def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): +def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() if algorithm is not None: PatchRLStatistics(algorithm) From ae3f2191a17d750a0dc11a41cbd2611b7fac1933 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:04:05 -0800 Subject: [PATCH 226/473] Update rl.py --- unsloth/models/rl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3682c71d7..b59381640 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -85,10 +85,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'vLLM SamplingParams'}}, ) def __init__({RLConfig_arguments}, - sampling_params = None + sampling_params = None, + *args, **kwargs, ): {RLConfig_extra_args} - super().__init__({RLConfig_call_args}) + super().__init__({RLConfig_call_args}, + *args, **kwargs) pass {RLTrainer_extras} @@ -97,11 +99,13 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): """ {__RLTrainer_doc__} """ - def __init__({RLTrainer_arguments} + def __init__({RLTrainer_arguments}, + *args, **kwargs, ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} - super().__init__({RLTrainer_call_args}) + super().__init__({RLTrainer_call_args}, + *args, **kwargs) pass ''' From 5e71435654124f1dbf43a0f3a743053a09db822f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:05:44 -0800 Subject: [PATCH 227/473] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b59381640..7e3282320 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -86,11 +86,11 @@ class Unsloth{RLConfig_name}({RLConfig_name}): ) def __init__({RLConfig_arguments}, sampling_params = None, - *args, **kwargs, + **kwargs, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}, - *args, **kwargs) + **kwargs) pass {RLTrainer_extras} @@ -100,12 +100,12 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): {__RLTrainer_doc__} """ def __init__({RLTrainer_arguments}, - *args, **kwargs, + **kwargs, ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}, - *args, **kwargs) + **kwargs) pass ''' From 883192ddfd3d94033233d971e8255f55f5be0280 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:08:04 -0800 Subject: [PATCH 228/473] Update rl.py --- unsloth/models/rl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7e3282320..28352b415 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -89,8 +89,7 @@ def __init__({RLConfig_arguments}, **kwargs, ): {RLConfig_extra_args} - super().__init__({RLConfig_call_args}, - **kwargs) + super().__init__({RLConfig_call_args}{RLConfig_kwargs}) pass {RLTrainer_extras} 
@@ -100,12 +99,11 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): {__RLTrainer_doc__} """ def __init__({RLTrainer_arguments}, - **kwargs, + **kwargs ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} - super().__init__({RLTrainer_call_args}, - **kwargs) + super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) pass ''' @@ -290,12 +288,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, + RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, ) From 22c1cc1ba5a146d032ca83ea7706fad6e85d64cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:16:07 -0800 Subject: [PATCH 229/473] Update rl.py --- unsloth/models/rl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 28352b415..30786ab6c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -340,6 +340,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): r"args.use_vllm = True\n\2", init, 1, ) + print(init) vllm_part = re.findall( r"(\n[\s]{8}"\ @@ -354,6 +355,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + print(new_vllm_part) # Get SamplingParams sampling_params = re.findall( @@ -363,6 +365,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): flags = re.MULTILINE | re.DOTALL, ) if len(sampling_params) != 1: return None + print(sampling_params) sampling_params = sampling_params[0] # Replace with our vLLM engine From 3fabc11a9cc4a2dc007b802a1125cdddfcd1a04e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:20:41 -0800 Subject: [PATCH 230/473] Update rl.py --- unsloth/models/rl.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 30786ab6c..225e0e48f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -340,7 +340,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): r"args.use_vllm = True\n\2", init, 1, ) - print(init) vllm_part = re.findall( r"(\n[\s]{8}"\ @@ -355,7 +354,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) - print(new_vllm_part) # Get SamplingParams sampling_params = re.findall( @@ -364,19 +362,19 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - if len(sampling_params) != 1: return None - print(sampling_params) - - sampling_params = sampling_params[0] - # Replace with our vLLM engine - sampling_params = \ - " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ - sampling_params # Add spaces - new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', 
None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" - init = init.replace(vllm_part, new_vllm_part) + print(len(sampling_params), RLTrainer_name) + if len(sampling_params) == 1: + sampling_params = sampling_params[0] + # Replace with our vLLM engine + sampling_params = \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + sampling_params # Add spaces + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) + pass # Search for vLLM calling in all child functions functions = dir(RLTrainer) From d9687d59ed85979567c579be6fee280319b274ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:21:59 -0800 Subject: [PATCH 231/473] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 5226c3c5b..82e82eb68 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -945,6 +945,7 @@ def patch_sft_trainer_tokenizer(): " from functools import partial\n"\ " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ + " print(1111)\n" "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" From 47373802a5829c8b5e5eb2c533e8a2fcd4ba5590 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:23:05 -0800 Subject: [PATCH 232/473] Update rl.py --- unsloth/models/rl.py | 50 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 225e0e48f..9b8b410f4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -349,31 +349,31 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): init, flags = re.MULTILINE | re.DOTALL, ) - if len(vllm_part) != 1: return None - - vllm_part, args = vllm_part[0][0], vllm_part[0][1] - # Strip all comments - new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) - - # Get SamplingParams - sampling_params = re.findall( - r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ - r"SamplingParams\(.+?\))", - new_vllm_part, - flags = re.MULTILINE | re.DOTALL, - ) - print(len(sampling_params), RLTrainer_name) - if len(sampling_params) == 1: - sampling_params = sampling_params[0] - # Replace with our vLLM engine - sampling_params = \ - " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ - sampling_params # Add spaces - new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" - init = init.replace(vllm_part, new_vllm_part) + if len(vllm_part) == 1: + vllm_part, args = vllm_part[0][0], vllm_part[0][1] + # Strip all comments + new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + + # Get SamplingParams + sampling_params = re.findall( + r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ + r"SamplingParams\(.+?\))", + new_vllm_part, + flags = re.MULTILINE | re.DOTALL, + ) + print(sampling_params) + if len(sampling_params) == 1: + sampling_params = sampling_params[0] + # Replace with our vLLM engine + sampling_params = \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + sampling_params # Add spaces + new_vllm_part = 
\ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) + pass pass # Search for vLLM calling in all child functions From 6abf22a253bef80407f3308c9792947fcb2fc85d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:25:37 -0800 Subject: [PATCH 233/473] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9b8b410f4..418741707 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -361,7 +361,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - print(sampling_params) if len(sampling_params) == 1: sampling_params = sampling_params[0] # Replace with our vLLM engine From 5edcdf80454685ab7048010674d81f679cc1bfb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:11:33 -0800 Subject: [PATCH 234/473] Update rl.py --- unsloth/models/rl.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 418741707..5ec418dda 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -74,6 +74,7 @@ def generate_with_clone(*args, **kwargs): RLTrainer_replacement = ''' from typing import * from dataclasses import dataclass, field +from packaging.version import Version @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): @@ -197,14 +198,25 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Check eval_dataset first if "eval_dataset" in call_args: check_eval_dataset = \ - "if getattr(args, 'eval_strategy', 'no') == 'no':\n"\ + "if getattr(args, 'eval_dataset', None) is not None and "\ + "getattr(args, 'eval_strategy', 'no') == 'no':\n"\ " args.eval_strategy = 'steps'\n"\ " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n" extra_args += check_eval_dataset pass - eval_changes = \ + # Check if gradient accumulation bug fix is applied + check_ga = \ "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\ + "if ga_steps is not None and ga_steps > 1:\n"\ + " from transformers import __version__ as transformers_version\n"\ + " if Version(transformers_version) <= Version('4.45.2'):\n"\ + " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\ + " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n" + + extra_args += check_ga + + eval_changes = \ "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\ " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\ " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\ @@ -236,7 +248,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit GA / bsz and weight_decay replacements = { - "output_dir" : 'unsloth_training_checkpoints', + "output_dir" : None, "logging_nan_inf_filter" : False, "per_device_train_batch_size" : 4, "gradient_accumulation_steps" : 2, @@ -265,6 +277,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += learning_rate_check pass + # Add output_dir saving + if "output_dir" in call_args: + # Default checks + saving_check = \ + "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"\ 
+ " output_dir = 'unsloth_training_checkpoints'\n"\ + " save_strategy = 'no'\n" + extra_args += saving_check + pass + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 7e55aef9da37607417146890aad50f7bd4d57007 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:22:19 -0800 Subject: [PATCH 235/473] max seq length --- unsloth/models/llama.py | 6 +++--- unsloth/models/rl.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c50f65e4b..5583702e7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1952,13 +1952,13 @@ def from_pretrained( Trainer._inner_training_loop = _fast_inner_training_loop # Save max_seq_length - model.max_seq_length = max_position_embeddings + model.max_seq_length = max_seq_length internal_model = model while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_position_embeddings + internal_model.max_seq_length = max_seq_length internal_model = internal_model.model pass - internal_model.max_seq_length = max_position_embeddings + internal_model.max_seq_length = max_seq_length # We check the tokenizer first for errors if fix_tokenizer: diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5ec418dda..dad658170 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -287,6 +287,25 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += saving_check pass + # Edit dataset_num_proc + if "dataset_num_proc" in call_args: + num_proc_check = \ + "if dataset_num_proc is None:\n"\ + " from multiprocessing import cpu_count\n"\ + " dataset_num_proc = cpu_count()\n" + extra_args += num_proc_check + pass + + # Check max_seq_length + if "max_seq_length" in call_args: + length_check = \ + "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'"\ + " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" + " max_seq_length = model.max_seq_length\n" + extra_args += length_check + pass + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 6a21b5039ffefbc678bd8b3196658ce04e68852a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:27:31 -0800 Subject: [PATCH 236/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index dad658170..a098c896f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -273,7 +273,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if " learning_rate" in call_args: learning_rate_check = \ "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')" + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! 
Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From 035d24e6d42b2d705e5312e97b52859e77852a63 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:00:44 -0800 Subject: [PATCH 237/473] Update rl.py --- unsloth/models/rl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a098c896f..0c34f5002 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -272,8 +272,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ - "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '"\ + "'Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '"\ + "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From b67327bf3eb559ed15058a73d9c317327935a3c4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:11:16 -0800 Subject: [PATCH 238/473] Patching --- unsloth/models/rl.py | 3 ++- unsloth/tokenizer_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0c34f5002..ab51e9cf6 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -302,9 +302,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if "max_seq_length" in call_args: length_check = \ "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ " 'model maximum sequence length is ' + str(model.max_seq_length) + '. 
We will reduce it.')\n" " max_seq_length = model.max_seq_length\n" + "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" extra_args += length_check pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 82e82eb68..ab3878613 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1056,5 +1056,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# FInally patch TRL tokenizer things -patch_sft_trainer_tokenizer() +# Finally patch TRL tokenizer things +# patch_sft_trainer_tokenizer() From 56bf7a1b3b5c57b4cf1b26fc33c7c14b43a340f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:53:46 -0800 Subject: [PATCH 239/473] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ab51e9cf6..3d5dbfdf3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -272,9 +272,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ - "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '"\ - "'Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '"\ + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\ + "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "\ "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From 8c236572134d1c4798339992d890363fbb56479e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:57:32 -0800 Subject: [PATCH 240/473] Update rl.py --- unsloth/models/rl.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3d5dbfdf3..a5db30d7c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -230,6 +230,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += eval_changes pass + # Check max_seq_length + if "max_seq_length" in call_args: + length_check = \ + "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ + " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" + " max_seq_length = model.max_seq_length\n" + "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" + extra_args += length_check + pass + # Add statistics as well! 
extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -298,17 +309,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass - # Check max_seq_length - if "max_seq_length" in call_args: - length_check = \ - "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ - " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" - " max_seq_length = model.max_seq_length\n" - "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" - extra_args += length_check - pass - # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From e735ab593636d8d12913e146e8848d214f2694d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:03:33 -0800 Subject: [PATCH 241/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a5db30d7c..f7265cff8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -245,6 +245,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ f"PatchRLStatistics('{trainer_file}')\n" + "print(args)\n" # Create RLTrainer args extra_args = extra_args.split("\n") From 484afd783efd90b949725a992a676d8cd1a3342b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:04:56 -0800 Subject: [PATCH 242/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f7265cff8..c41b45f1b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -244,7 +244,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n" + f"PatchRLStatistics('{trainer_file}')\n"\ "print(args)\n" # Create RLTrainer args From 4a23920d2bf2f1ba358c5f9a0cbfca09022c4506 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:20:14 -0800 Subject: [PATCH 243/473] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c41b45f1b..4c488187c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -72,6 +72,7 @@ def generate_with_clone(*args, **kwargs): RLTrainer_replacement = ''' +import os from typing import * from dataclasses import dataclass, field from packaging.version import Version @@ -188,7 +189,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ "if not use_bf16 and not use_fp16:\n"\ " args.fp16 = float16\n"\ - " args.bf16 = not float16\n" + " args.bf16 = not float16\n"\ + " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n" extra_args += mixed_precision pass @@ -244,8 +246,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add statistics as well! 
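For reference, once the string fragments above are joined and exec'd into the generated trainer __init__, the injected learning-rate guard behaves roughly like this standalone sketch (the learning_rate value here is hypothetical):

# Rough standalone equivalent of the injected learning-rate sanity check.
learning_rate = 2e-4  # hypothetical value a user might pass

if learning_rate < 1e-7:
    raise FloatingPointError(
        f"Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "
        "Consider increasing it, otherwise gradient updates will be close to 0!"
    )
if learning_rate > 1:
    raise OverflowError(
        f"Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! "
        "Consider decreasing it to 1e-1, otherwise gradient updates will explode!"
    )
print("learning_rate =", learning_rate, "passed both checks")
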
extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n"\ - "print(args)\n" + f"PatchRLStatistics('{trainer_file}')\n" # Create RLTrainer args extra_args = extra_args.split("\n") From 19b16bb3025f6341a4f280b0a50d2ddeaf513240 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:19:16 -0800 Subject: [PATCH 244/473] NEFTune --- unsloth/models/llama.py | 7 +++++-- unsloth/models/rl.py | 39 +++++++++++++++++++++++++++++++++++++- unsloth/tokenizer_utils.py | 1 - 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5583702e7..6a8049192 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -15,6 +15,7 @@ import torch import gc import math +from functools import partial from typing import Optional, Tuple, List, Union from ._utils import * from ._utils import __version__ @@ -1802,8 +1803,6 @@ def from_pretrained( model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) model.vllm_engine = llm model.fast_generate = model.vllm_engine.generate - - from functools import partial model.fast_generate_batches = partial(generate_batches, model.vllm_engine) pass # Return old flag @@ -2632,6 +2631,10 @@ def patch_peft_model( gc.collect() torch.cuda.empty_cache() pass + + # Add for_inference and for_training + model.for_training = partial(FastLlamaModel.for_training, model) + model.for_inference = partial(FastLlamaModel.for_inference, model) return model pass diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4c488187c..ec1d65ba4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -71,6 +71,16 @@ def generate_with_clone(*args, **kwargs): pass +# Handles NEFTune +def neftune_post_forward_hook(module, input, output): + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output +pass + + RLTrainer_replacement = ''' import os from typing import * @@ -106,6 +116,7 @@ def __init__({RLTrainer_arguments}, if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) + {RLTrainer_post} pass ''' @@ -164,6 +175,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Process RLTrainer first arguments, call_args = processed[0] + RLTrainer_post = "" # Add tokenizer if not seen if "tokenizer" not in parameters and "processing_class" in parameters: @@ -215,7 +227,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if Version(transformers_version) <= Version('4.45.2'):\n"\ " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\ " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n" - extra_args += check_ga eval_changes = \ @@ -243,6 +254,29 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += length_check pass + # Check NEFTune + if "neftune_noise_alpha" in call_args: + neftune_check = \ + "if hasattr(self, 'neftune_hook_handle'):\n"\ + " self.neftune_hook_handle.remove()\n"\ + " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ + "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\ + " model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\ + " self.neftune_hook_handle = 
self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ + "pass\n" + RLTrainer_post += neftune_check + pass + + # Enable for training and move padding side of tokenizer to right + RLTrainer_post += \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -251,6 +285,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Create RLTrainer args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) + RLTrainer_post = RLTrainer_post.split("\n") + RLTrainer_post = "\n".join(" "*8 + x for x in RLTrainer_post) RLTrainer_arguments = arguments RLTrainer_extra_args = extra_args RLTrainer_call_args = call_args @@ -344,6 +380,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, + RLTrainer_post = RLTrainer_post, ) # Create new function diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index ab3878613..0300d1330 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -945,7 +945,6 @@ def patch_sft_trainer_tokenizer(): " from functools import partial\n"\ " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ - " print(1111)\n" "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" From 7e19c0f6f3dfed00c6aa2ee7f8fa1380beb73c77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:49:09 -0800 Subject: [PATCH 245/473] Update rl.py --- unsloth/models/rl.py | 48 +++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ec1d65ba4..2ab8f218a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -188,7 +188,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit bf16, fp16 by checking model's torch_dtype directly extra_args = "" - if "args" in call_args: + if "args" in call_args and "model" in call_args: mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ "use_fp16 = getattr(args, 'fp16', False)\n"\ @@ -239,23 +239,30 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ "if not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16; args.fp16_full_eval = args.fp16\n" - extra_args += eval_changes pass # Check max_seq_length - if "max_seq_length" in call_args: + if "model" in call_args: length_check = \ - "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ - " 'model maximum sequence length is ' + str(model.max_seq_length) + '. 
We will reduce it.')\n" - " max_seq_length = model.max_seq_length\n" - "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" + "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"\ + " pass\n"\ + "else:\n"\ + " model_max_seq_length = getattr(model, 'max_seq_length', None)\n"\ + " args_max_seq_length = getattr(args, 'max_seq_length', None)\n"\ + " if args_max_seq_length is None and model_max_seq_length is not None:\n"\ + " max_seq_length = model.max_seq_length\n"\ + " if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n" + " elif args_max_seq_length is not None and model_max_seq_length is not None:\n"\ + " if args_max_seq_length > model_max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\ + " the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\ + " args.max_seq_length = model_max_seq_length\n\n" extra_args += length_check pass # Check NEFTune - if "neftune_noise_alpha" in call_args: + if "model" in call_args: neftune_check = \ "if hasattr(self, 'neftune_hook_handle'):\n"\ " self.neftune_hook_handle.remove()\n"\ @@ -268,15 +275,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Enable for training and move padding side of tokenizer to right - RLTrainer_post += \ - "if model is not None and hasattr(model, 'for_training'):\n"\ - " model.for_training()\n"\ - "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ - "if 'processing_class' in locals():\n"\ - " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ - "processing_class.tokenizer.padding_side = 'right'\n" - + if "model" in call_args: + training_check = \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + RLTrainer_post += training_check + pass + # Add statistics as well! 
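The NEFTune wiring added in the patch above is a standard PyTorch forward hook on the input-embedding module; a minimal self-contained sketch (the embedding sizes and alpha value are made up, the hook body mirrors the patched code):

import torch
import torch.nn as nn

def neftune_post_forward_hook(module, input, output):
    # Add uniform noise scaled by alpha / sqrt(seq_len * hidden_dim),
    # but only while the module is in training mode.
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output

embeddings = nn.Embedding(1000, 64)          # hypothetical input embeddings
embeddings.neftune_noise_alpha = 5.0         # illustrative alpha
handle = embeddings.register_forward_hook(neftune_post_forward_hook)

tokens = torch.randint(0, 1000, (2, 8))
noisy = embeddings(tokens)                   # noise is added during training
handle.remove()                              # detaching the hook disables NEFTune
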
extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -347,6 +357,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass + # Edit report_to and default it to nothing if max_steps is like 60 + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 0ac3d15339f1dd3d2d00aa0f8f8d3ec6b1ad8bbe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:54:39 -0800 Subject: [PATCH 246/473] Update rl.py --- unsloth/models/rl.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2ab8f218a..c26d450ca 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -257,10 +257,23 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if args_max_seq_length > model_max_seq_length:\n"\ " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\ " the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\ - " args.max_seq_length = model_max_seq_length\n\n" + " args.max_seq_length = model_max_seq_length\n" extra_args += length_check pass + # Enable for training and move padding side of tokenizer to right + if "model" in call_args: + training_check = \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + extra_args += training_check + pass + # Check NEFTune if "model" in call_args: neftune_check = \ @@ -274,19 +287,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_post += neftune_check pass - # Enable for training and move padding side of tokenizer to right - if "model" in call_args: - training_check = \ - "if model is not None and hasattr(model, 'for_training'):\n"\ - " model.for_training()\n"\ - "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ - "if 'processing_class' in locals():\n"\ - " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ - "processing_class.tokenizer.padding_side = 'right'\n" - RLTrainer_post += training_check - pass - # Add statistics as well! 
extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ From 70b341cc6ceb7645c2fb5db2d5faaa88c5490adc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:56:09 -0800 Subject: [PATCH 247/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c26d450ca..2a3a9eb20 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -116,7 +116,7 @@ def __init__({RLTrainer_arguments}, if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) - {RLTrainer_post} +{RLTrainer_post} pass ''' From 3b641de6f54632043b9f49b07a7ebe99f2a18368 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:57:35 -0800 Subject: [PATCH 248/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2a3a9eb20..c55e4141d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -269,7 +269,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ "if 'processing_class' in locals():\n"\ " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + " if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "\ "processing_class.tokenizer.padding_side = 'right'\n" extra_args += training_check pass From 30ad4c4fe897ff76b4ecabd958dd68bff6b7924d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:00:53 -0800 Subject: [PATCH 249/473] Update rl.py --- unsloth/models/rl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c55e4141d..2d75452b2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -71,7 +71,14 @@ def generate_with_clone(*args, **kwargs): pass -# Handles NEFTune +RLTrainer_replacement = ''' +import os +from typing import * +from dataclasses import dataclass, field +from packaging.version import Version +import torch + +# https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_utils.py#L126 def neftune_post_forward_hook(module, input, output): if module.training: dims = torch.tensor(output.size(1) * output.size(2)) @@ -80,13 +87,6 @@ def neftune_post_forward_hook(module, input, output): return output pass - -RLTrainer_replacement = ''' -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version - @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): """ From a848c019b09ea65b19d8e569bb96b6df98da84fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:34:29 -0800 Subject: [PATCH 250/473] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2d75452b2..b1ee649c8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -282,7 +282,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\ " model.get_input_embeddings().neftune_noise_alpha = 
self.neftune_noise_alpha\n"\ - " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ "pass\n" RLTrainer_post += neftune_check pass From f25abe6a700747ee5376ed5da1315c65d9e23cf6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:34:41 -0800 Subject: [PATCH 251/473] Update rl.py --- unsloth/models/rl.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b1ee649c8..4e7fcfa7a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -78,15 +78,6 @@ def generate_with_clone(*args, **kwargs): from packaging.version import Version import torch -# https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_utils.py#L126 -def neftune_post_forward_hook(module, input, output): - if module.training: - dims = torch.tensor(output.size(1) * output.size(2)) - mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) - output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) - return output -pass - @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): """ From 069446362f7d496909dd02f8dfe5390be21be858 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:35:34 -0800 Subject: [PATCH 252/473] Extra replacements --- unsloth/models/rl.py | 11 ++++++- unsloth/models/rl_replacements.py | 50 +++++++++++++++++++++++++++++++ unsloth/tokenizer_utils.py | 3 +- 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 unsloth/models/rl_replacements.py diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4e7fcfa7a..3e1b6993f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -23,7 +23,9 @@ import re from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics - +from .rl_replacements import ( + RL_EXTRA_ARGS, +) def PatchRL(FastLanguageModel): @@ -282,6 +284,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ f"PatchRLStatistics('{trainer_file}')\n" + # Patch optional args + if trainer_file in RL_EXTRA_ARGS: + process_extra_args = RL_EXTRA_ARGS[trainer_file] + for process_extra_arg in process_extra_args: + extra_args += process_extra_args(call_args, extra_args) + pass + # Create RLTrainer args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py new file mode 100644 index 000000000..56ad57f5c --- /dev/null +++ b/unsloth/models/rl_replacements.py @@ -0,0 +1,50 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = [ + "RL_EXTRA_ARGS", +] + +RL_EXTRA_ARGS = dict() + +def sft_trainer_fix_untraiend_tokens(call_args, extra_args): + if "model" in call_args and "train_dataset" in call_args: + fix_tokenizer = \ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', set())\n"\ + "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ + "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ + "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ + "fix_zero_training_loss(model, tokenizer, train_dataset)\n" + return fix_tokenizer + return "" +pass +RL_EXTRA_ARGS["sft_trainer"] = [sft_trainer_fix_untraiend_tokens,] + + +def dpo_trainer_fix_columns(call_args, extra_args): + if "model" in call_args and "train_dataset" in call_args: + fix_dpo = \ + "if hasattr(train_dataset, 'column_names'):\n"\ + " column_names = set(train_dataset.column_names)\n"\ + " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + " if all(x in column_names for x in check):\n"\ + " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + " del check, column_names\n"\ + return fix_dpo + return "" +pass +RL_EXTRA_ARGS["dpo_trainer"] = [dpo_trainer_fix_columns,] diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 0300d1330..404fce319 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -59,6 +59,7 @@ [x.lower() for x in IGNORED_TOKENIZER_NAMES] + \ [x.lower()+"-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES] ) +os.environ["UNSLOTH_IGNORED_TOKENIZER_NAMES"] = "\n".join(IGNORED_TOKENIZER_NAMES) # Check environments keynames = "\n" + "\n".join(os.environ.keys()) @@ -1055,5 +1056,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# Finally patch TRL tokenizer things +# Finally patch TRL tokenizer things -> moved to RL # patch_sft_trainer_tokenizer() From 8cc0338fb3d5e7281da39a00340bb129c05594cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:37:18 -0800 Subject: [PATCH 253/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 56ad57f5c..a09fcb1fb 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -43,7 +43,7 @@ def dpo_trainer_fix_columns(call_args, extra_args): " 'prompt_input_ids', 'prompt_attention_mask']\n"\ " if all(x in column_names for x in check):\n"\ " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ - " del check, column_names\n"\ + " del check, column_names\n" return fix_dpo return "" pass From a145a835459acc9e59fc603ac235ae30fd1612e0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:39:55 -0800 Subject: [PATCH 254/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3e1b6993f..d91a6680d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -288,7 +288,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if trainer_file in RL_EXTRA_ARGS: process_extra_args = RL_EXTRA_ARGS[trainer_file] for process_extra_arg in process_extra_args: 
- extra_args += process_extra_args(call_args, extra_args) + extra_args += process_extra_arg(call_args, extra_args) pass # Create RLTrainer args From 39fbcfb0add504b974f0c6b5a5ec23061d20a423 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:10:32 -0800 Subject: [PATCH 255/473] extra RL replacements --- unsloth/models/rl.py | 13 ++++++-- unsloth/models/rl_replacements.py | 54 ++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d91a6680d..24a5c8d1f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -25,6 +25,7 @@ from unsloth_zoo.logging_utils import PatchRLStatistics from .rl_replacements import ( RL_EXTRA_ARGS, + RL_FUNCTIONS, ) def PatchRL(FastLanguageModel): @@ -365,8 +366,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_extra_args = extra_args RLConfig_call_args = call_args - # Patch vLLM - RLTrainer_extras = patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) + # Patch vLLM and other functions + RLTrainer_extras = patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" @@ -414,7 +415,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass -def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): +def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): init = inspect.getsource(RLTrainer.__init__) old_init = init @@ -475,6 +476,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): functions = [x for x in functions if f"def {x}" in RLTrainer_source] changed = {"__init__" : (old_init, init,)} + edit_functions = RL_FUNCTIONS.get(trainer_file, []) for function in functions: if not hasattr(RLTrainer, function): continue @@ -483,6 +485,11 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): except: continue original_source = source + # Check for function + for edit_function in edit_functions: + source = edit_function(function, source) + pass + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model source = re.sub( r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n", diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a09fcb1fb..56c5c7ad9 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -14,14 +14,19 @@ __all__ = [ "RL_EXTRA_ARGS", + "RL_FUNCTIONS", ] -RL_EXTRA_ARGS = dict() +import re +from collections import defaultdict +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) + def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ - "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', set())\n"\ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')\n"\ "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ @@ -30,7 +35,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): return fix_tokenizer return "" pass -RL_EXTRA_ARGS["sft_trainer"] = [sft_trainer_fix_untraiend_tokens,] +RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) def 
dpo_trainer_fix_columns(call_args, extra_args): @@ -47,4 +52,45 @@ def dpo_trainer_fix_columns(call_args, extra_args): return fix_dpo return "" pass -RL_EXTRA_ARGS["dpo_trainer"] = [dpo_trainer_fix_columns,] +RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns) + + +def sft_trainer_prepare_dataset(function_name, function): + if function_name != "_prepare_non_packed_dataloader" and \ + function_name != "_prepare_dataset": return + + check_text = \ + "\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ + "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ + "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ + "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ + "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ + "chat_template = '' if chat_template is None else chat_template\n"\ + "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ + "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\ + "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\ + " from functools import partial\n"\ + " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ + " processing_class = tokenizer\n"\ + "else:\n"\ + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" + + check_text = check_text.split("\n") + check_text = "\n".join(" "*where + x for x in check_text) + check_text = check_text.rstrip() + "\n" + + # .*? matches first match. .+? matches final match. 
+ replacer = re.findall( + f"def {function_name}\(.*?\).*?\:\n", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + function = function.replace(replacer, replacer + check_text) + pass + return function +pass +RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) From 2e68bb352569e6fb5226f919a21c398f8a8b6bb6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:13:31 -0800 Subject: [PATCH 256/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 56c5c7ad9..b60a10319 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -57,8 +57,8 @@ def dpo_trainer_fix_columns(call_args, extra_args): def sft_trainer_prepare_dataset(function_name, function): if function_name != "_prepare_non_packed_dataloader" and \ - function_name != "_prepare_dataset": return - + function_name != "_prepare_dataset": return function + check_text = \ "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ @@ -90,7 +90,7 @@ def sft_trainer_prepare_dataset(function_name, function): if len(replacer) != 0: replacer = replacer[0] function = function.replace(replacer, replacer + check_text) - pass + pass return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) From 82d3f6af8198d8595f2ea6fae39f2a89c3569459 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:14:41 -0800 Subject: [PATCH 257/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b60a10319..6098336e1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -78,7 +78,7 @@ def sft_trainer_prepare_dataset(function_name, function): " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" check_text = check_text.split("\n") - check_text = "\n".join(" "*where + x for x in check_text) + check_text = "\n".join(" "*4 + x for x in check_text) check_text = check_text.rstrip() + "\n" # .*? matches first match. .+? matches final match. 
From 0c691cf8213aa2b9d79232860e4cdb5a3bdfa162 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:16:56 -0800 Subject: [PATCH 258/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6098336e1..5c6cb0c64 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -26,7 +26,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ - "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')\n"\ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\ "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ From cd6f9b684f967c27e1944987f34bd3ec975ebcdc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:18:55 -0800 Subject: [PATCH 259/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5c6cb0c64..c98adfee8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -78,7 +78,7 @@ def sft_trainer_prepare_dataset(function_name, function): " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" check_text = check_text.split("\n") - check_text = "\n".join(" "*4 + x for x in check_text) + check_text = "\n".join(" "*8 + x for x in check_text) check_text = check_text.rstrip() + "\n" # .*? matches first match. .+? matches final match. 
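The '\n' to '\\n' change above matters because the reading side is emitted from inside another string literal; the underlying hand-off is just an environment-variable round trip, roughly as in this sketch (the tokenizer names are hypothetical):

import os

# Writer side: join the names with real newlines.
names = ["some-tokenizer-a", "some-tokenizer-b"]   # hypothetical entries
os.environ["UNSLOTH_IGNORED_TOKENIZER_NAMES"] = "\n".join(names)

# Reader side: split on the newline character. In the patch this line is itself
# embedded in a string literal, so the separator has to be escaped as '\\n'
# to avoid a raw newline landing inside the generated source.
restored = os.environ.get("UNSLOTH_IGNORED_TOKENIZER_NAMES", "").split("\n")
assert restored == names
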
From be568b03e9eb2a3a26c7b49785a0abb06c588224 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:31:23 -0800 Subject: [PATCH 260/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c98adfee8..b7d018915 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -60,7 +60,6 @@ def sft_trainer_prepare_dataset(function_name, function): function_name != "_prepare_dataset": return function check_text = \ - "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ From 9ade7824064db4b346061812797a3095fd08d163 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:00:44 -0800 Subject: [PATCH 261/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b7d018915..f3d5039a6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -23,6 +23,7 @@ RL_FUNCTIONS = defaultdict(list) +# Check untrained tokens def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ @@ -38,6 +39,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) +# Remove DPO columns which might randomnly be tokenized def dpo_trainer_fix_columns(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_dpo = \ @@ -55,6 +57,7 @@ def dpo_trainer_fix_columns(call_args, extra_args): RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns) +# Fix tokenizer double BOS def sft_trainer_prepare_dataset(function_name, function): if function_name != "_prepare_non_packed_dataloader" and \ function_name != "_prepare_dataset": return function @@ -93,3 +96,23 @@ def sft_trainer_prepare_dataset(function_name, function): return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) + + +# Ignore mean_token_accuracy since it needs logits +def sft_trainer_compute_loss(function_name, function): + if function_name != "compute_loss": return function + + # .*? matches first match. .+? matches final match. 
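Several of these helpers reuse the same splice pattern: locate the def header with a non-greedy DOTALL regex, then insert extra lines straight after it. A minimal standalone sketch against a hypothetical target function:

import re

function = (
    "def _prepare_dataset(self, dataset, args):\n"
    "    return dataset\n"
)
check_text = "    print('patched!')\n"   # hypothetical injected line

# .*? is non-greedy, so only the header up to the first ':' + newline is captured.
replacer = re.findall(
    r"def _prepare_dataset\(.*?\).*?\:\n",
    function,
    flags = re.MULTILINE | re.DOTALL,
)
if len(replacer) != 0:
    replacer = replacer[0]
    function = function.replace(replacer, replacer + check_text)
print(function)   # the body now starts with the injected line
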
+ replacer = re.findall( + f"\.compute_loss\(.*?\)", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + returner = " "*8 + "return (loss, outputs) if return_outputs else loss" + function = function.replace(replacer, replacer + returner) + pass + return function +pass +RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From e49815038ac2fb5d29af342e3cc6b6ca273a0885 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:02:22 -0800 Subject: [PATCH 262/473] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6a8049192..3a87ab56d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2145,8 +2145,6 @@ def get_peft_model( signature = str(inspect.signature(LoraConfig)) SUPPORTS_LOFTQ = "loftq_config" in signature SUPPORTS_RSLORA = "use_rslora" in signature - - assert(max_seq_length <= model.max_seq_length) if lora_dropout != 0: logger.warning_once( From 2a5aa3d0ba1710dd7e9a225470cf7fe457d88e64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:02:41 -0800 Subject: [PATCH 263/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f3d5039a6..65138feb1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -110,7 +110,7 @@ def sft_trainer_compute_loss(function_name, function): ) if len(replacer) != 0: replacer = replacer[0] - returner = " "*8 + "return (loss, outputs) if return_outputs else loss" + returner = "\n" + " "*8 + "return (loss, outputs) if return_outputs else loss" function = function.replace(replacer, replacer + returner) pass return function From 25245382083bb5dff58f853e2cdb70fc70012702 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:10:11 -0800 Subject: [PATCH 264/473] Update _utils.py --- unsloth/models/_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2ec4adaa1..6aa7f94cf 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -131,6 +131,7 @@ # Ignore logging messages class HideLoggingMessage(logging.Filter): + __slots__ = "text", def __init__(self, text): self.text = text def filter(self, x): return not (self.text in x.getMessage()) pass @@ -138,6 +139,8 @@ def filter(self, x): return not (self.text in x.getMessage()) # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. from transformers.training_args import logger as transformers_training_args_logger transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) +# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. +transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed")) del transformers_training_args_logger # Using the default loss: `ForCausalLMLoss`. 
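The HideLoggingMessage filter touched above is a generic trick for silencing one specific warning; a small sketch with a hypothetical logger and message:

import logging

class HideLoggingMessage(logging.Filter):
    # Drop any record whose rendered message contains the given substring.
    __slots__ = ("text",)
    def __init__(self, text):
        self.text = text
    def filter(self, record):
        return self.text not in record.getMessage()

logging.basicConfig(level = logging.WARNING)
logger = logging.getLogger("example")                     # hypothetical logger
logger.addFilter(HideLoggingMessage("torch.distributed"))

logger.warning("torch.distributed process group is initialized ...")  # suppressed
logger.warning("an unrelated warning")                                 # still printed
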
From c9ba000df50d2338fbbf55e1396847c2862ad4c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:45:26 -0800 Subject: [PATCH 265/473] Update loader_utils.py --- unsloth/models/loader_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py index b778b7e95..e3eadd8c0 100644 --- a/unsloth/models/loader_utils.py +++ b/unsloth/models/loader_utils.py @@ -58,6 +58,11 @@ def __get_model_name( elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER: + # Support returning original full -bnb-4bit name if specified specifically + # since we'll map it to the dynamic version instead + if lower_model_name.endswith("-bnb-4bit"): + return lower_model_name + new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name] # logger.warning_once( # f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\ From 5b2fd7272860850c79a9d8b130d830a5300bc655 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:47:33 -0800 Subject: [PATCH 266/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24a5c8d1f..1639590c2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -401,6 +401,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, + overwrite = False, ) # Patch Trainer From 3466186a78496a4849b7fe93033572255cbc9956 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:58:26 -0800 Subject: [PATCH 267/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 65138feb1..ba759095e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -33,6 +33,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ "fix_zero_training_loss(model, tokenizer, train_dataset)\n" + "print(1111)\n", return fix_tokenizer return "" pass From 5dc88470026ddd47380061961b9e18f39bdbb0e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 01:53:16 -0800 Subject: [PATCH 268/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ba759095e..65138feb1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -33,7 +33,6 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ "fix_zero_training_loss(model, tokenizer, train_dataset)\n" - "print(1111)\n", return fix_tokenizer return "" pass From 9aad48e1ee1ac1de72bd7c2b132ca27bc2b9418f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 02:27:34 -0800 Subject: [PATCH 269/473] Update rl.py --- unsloth/models/rl.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1639590c2..cf351ebf3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -428,19 +428,27 @@ def patch_functions(RLTrainer, 
trainer_file, RLTrainer_name, all_imports, import init = init.replace("get_peft_model(model, peft_config)", "model") # Set use_vllm if not set - init = re.sub( - r"\)([ ]{0,}\-\>[ ]{0,}None[ ]{0,}):\n([\s]{4})", - r"):\n\2 "\ - r"if hasattr(model, 'vllm_engine') and "\ - r"getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ - r"args.use_vllm = True\n\2", - init, 1, - ) + if "args.use_vllm" in init and "model" in init and "args" in init: + # .*? matches first match. .+? matches final match. + replacer = re.findall( + "def __init__\(.*?\).*?\:\n", + init, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + vllm_setter = "\n" + " "*8 + \ + "if hasattr(model, 'vllm_engine') and "\ + "getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + "args.use_vllm = True\n" + init = init.replace(replacer, replacer + vllm_setter) + pass + pass vllm_part = re.findall( r"(\n[\s]{8}"\ - r"if (self|args)\.use_vllm\:.+?"\ - r"\n[\s]{8,}"\ + r"if (self|args)\.use_vllm\:.*?"\ + r"\n[\s]{8}"\ "else:\n)", init, flags = re.MULTILINE | re.DOTALL, From f121a5c37dc5f087c925944b9ee798d13f288eaa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:03:43 -0800 Subject: [PATCH 270/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3a87ab56d..4f77280ad 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1154,6 +1154,7 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output + print("========== dtype = ", logits.dtype) return CausalLMOutputWithPast( loss=loss, logits=logits, From 5052d354e5f6cfd8f8fe15c2b3a3ef972793561a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:08:40 -0800 Subject: [PATCH 271/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4f77280ad..eaf4f8b73 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,8 +1153,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - - print("========== dtype = ", logits.dtype) + return CausalLMOutputWithPast( loss=loss, logits=logits, From a11aa96555440aed6ee94d281e37c625df27ef80 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:24:32 -0800 Subject: [PATCH 272/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index eaf4f8b73..fb05e052d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,7 +1153,8 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + + print(loss, logits) return CausalLMOutputWithPast( loss=loss, logits=logits, From a6abe0261c2e3264dd1aa90e32d69e4ffdb0e921 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:32:12 -0800 Subject: [PATCH 273/473] Update llama.py --- unsloth/models/llama.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fb05e052d..e03f73301 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1154,7 +1154,13 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] 
return (loss,) + output if loss is not None else output - print(loss, logits) + print(CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )) return CausalLMOutputWithPast( loss=loss, logits=logits, From d867faa1dc845c70e548caa25353d87c491130c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:50:07 -0800 Subject: [PATCH 274/473] autocast --- unsloth/models/rl.py | 1 + unsloth/models/rl_replacements.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cf351ebf3..466101d16 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -80,6 +80,7 @@ def generate_with_clone(*args, **kwargs): from dataclasses import dataclass, field from packaging.version import Version import torch +from contextlib import nullcontext @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 65138feb1..2ea12f69c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -116,3 +116,22 @@ def sft_trainer_compute_loss(function_name, function): return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) + + +# Autocast precision for GRPO +def grpo_trainer__prepare_inputs(function_name, function): + if function_name != "_prepare_inputs": return function + + if "with torch.inference_mode()" not in function: return function + + function = function.replace( + "with torch.inference_mode()", + + "with torch.inference_mode(), "\ + "torch.amp.autocast(device_type = 'cuda', "\ + "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext()", + ) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) From 44c9228b8d4360146d53220721bcd6692bc5d1de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:50:32 -0800 Subject: [PATCH 275/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2ea12f69c..67027f0b4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -125,12 +125,12 @@ def grpo_trainer__prepare_inputs(function_name, function): if "with torch.inference_mode()" not in function: return function function = function.replace( - "with torch.inference_mode()", + "with torch.inference_mode():", "with torch.inference_mode(), "\ "torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext()", + "if not torch.is_autocast_enabled('cuda') else nullcontext():", ) return function pass From e83d854ae9e8cd03655b78f70f56923af155f537 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:56:12 -0800 Subject: [PATCH 276/473] Update llama.py --- unsloth/models/llama.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e03f73301..eaf4f8b73 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,14 +1153,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) 
+ outputs[1:] return (loss,) + output if loss is not None else output - - print(CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - )) + return CausalLMOutputWithPast( loss=loss, logits=logits, From 623eb656feeed7800a6f62360457598a9eb41991 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:16:31 -0800 Subject: [PATCH 277/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 67027f0b4..a101d35a0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -135,3 +135,14 @@ def grpo_trainer__prepare_inputs(function_name, function): return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) + + +# Remove _move_model_to_vllm +def grpo_trainer__move_model_to_vllm(function_name, function): + if function_name != "_move_model_to_vllm": return function + + # .*? matches first match. .+? matches final match. + function = "def _move_model_to_vllm(*args, **kwargs): return None\n" + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From 7e612f0a567de70a85cbb296efe0ef3918e48969 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:19:13 -0800 Subject: [PATCH 278/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a101d35a0..9405fef57 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -143,6 +143,6 @@ def grpo_trainer__move_model_to_vllm(function_name, function): # .*? matches first match. .+? matches final match. function = "def _move_model_to_vllm(*args, **kwargs): return None\n" - return function + return function.find("def") * " " + function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From a45266be8ea5cab78982254ee46feac7c21ac6c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:19:34 -0800 Subject: [PATCH 279/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9405fef57..0063ea4af 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -142,7 +142,7 @@ def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function # .*? matches first match. .+? matches final match. 
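The autocast wrapper spliced into _prepare_inputs above reduces to a context-manager selection; a minimal sketch (assumes a CUDA device, and the dtype choice follows the same ACCELERATE_MIXED_PRECISION convention as the patch):

import os
import torch
from contextlib import nullcontext

# Pick fp16 or bf16 from the environment, defaulting to fp16.
dtype = (torch.float16
         if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
         else torch.bfloat16)

# Avoid nesting autocast if some outer code already enabled it.
ctx = (torch.amp.autocast(device_type = "cuda", dtype = dtype)
       if not torch.is_autocast_enabled("cuda") else nullcontext())

with torch.inference_mode(), ctx:
    pass  # generation / log-prob computation would run here
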
- function = "def _move_model_to_vllm(*args, **kwargs): return None\n" + function = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" return function.find("def") * " " + function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From c855d7ef663cddad980a6c0dcb95bbdf146f7b8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:23:47 -0800 Subject: [PATCH 280/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0063ea4af..0f342ec86 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -142,7 +142,7 @@ def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function # .*? matches first match. .+? matches final match. - function = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" - return function.find("def") * " " + function + replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" + return " "*function.find("def") + replacement pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From d7cefba3e2b00f4fe066f6f547afd44ea5b67dac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:44:47 -0800 Subject: [PATCH 281/473] Update llama.py --- unsloth/models/llama.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index eaf4f8b73..0b567b023 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -448,20 +448,28 @@ def LlamaAttention_fast_forward( A = flash_attn_func(Q, K, V, causal = True) else: # Grouped query attention - if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + if SDPA_HAS_GQA: + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2)#.contiguous() + else: + if n_groups != 1: + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) + pass + # Must be contiguous or else results are False! + # https://github.com/pytorch/pytorch/issues/112577 + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! 
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2).contiguous() + pass pass attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) @@ -1153,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + return CausalLMOutputWithPast( loss=loss, logits=logits, From 52d996aaf45e2cb8379f2533ca766dcf3abb4fad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:50:45 -0800 Subject: [PATCH 282/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0f342ec86..781f5984d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -146,3 +146,17 @@ def grpo_trainer__move_model_to_vllm(function_name, function): return " "*function.find("def") + replacement pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) + + +# Edit _get_per_token_logps +def grpo_trainer__get_per_token_logps(function_name, function): + if function_name != "_get_per_token_logps": return function + + # Set attention_mask to boolean + function = function.replace( + "attention_mask=attention_mask", + "attention_mask=attention_mask.to(torch.bool)" + ) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From 56f5b31d4c45eb7ca19c858d8161009979826572 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:57:01 -0800 Subject: [PATCH 283/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023..7481b833d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print("=====================") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 5f1e98cb9e49f6094c22933ad97c55f8d38a9650 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:01:23 -0800 Subject: [PATCH 284/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7481b833d..1b1da9001 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -375,6 +375,7 @@ def LlamaAttention_fast_forward( del self.RH_Q del self.attention pass + print(attention_mask) bsz, q_len, _ = hidden_states.size() @@ -449,7 +450,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print("=====================") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
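# Illustrative sketch of the manual grouped-query-attention expansion used in the
# fallback branch above (not part of the diff; the tensor sizes are made-up examples).
# Each KV head serves n_groups query heads, so K/V are broadcast from
# (bsz, n_kv_heads, seq, head_dim) up to (bsz, n_kv_heads * n_groups, seq, head_dim).
import torch

bsz, n_kv_heads, n_groups, kv_seq_len, head_dim = 2, 8, 4, 16, 64
n_heads = n_kv_heads * n_groups

K = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
print(K.shape)  # torch.Size([2, 32, 16, 64])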
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From e713129b867b331ba920adabeaeb3aace5c0b99d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:07:50 -0800 Subject: [PATCH 285/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1b1da9001..3a9ee5331 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -375,7 +375,6 @@ def LlamaAttention_fast_forward( del self.RH_Q del self.attention pass - print(attention_mask) bsz, q_len, _ = hidden_states.size() @@ -709,7 +708,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None elif self.training: - attention_mask = None + # attention_mask = None padding_mask = None else: # if 0 in attention_mask: From 310fc16da5d59634b5fec2edc80152b767132cbb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:10:48 -0800 Subject: [PATCH 286/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3a9ee5331..452bb78e2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print(attention_mask.shape, Q.shape, K.shape, V.shape) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 76a122e9012473d5aa1d027bf242e8e4d76bf2f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:11:09 -0800 Subject: [PATCH 287/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 452bb78e2..e04c573c6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask.shape, Q.shape, K.shape, V.shape) + print(attention_mask.shape, Q.shape, K.shape, V.shape, attention_mask.dtype) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 2dd29e57654a8036646da5fb82f9c2060cd20b5f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:18:07 -0800 Subject: [PATCH 288/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 781f5984d..aaa5b7214 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -153,10 +153,10 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function # Set attention_mask to boolean - function = function.replace( - "attention_mask=attention_mask", - "attention_mask=attention_mask.to(torch.bool)" - ) + # function = function.replace( + # "attention_mask=attention_mask", + # "attention_mask=attention_mask.to(torch.bool)" + # ) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From 3c5be915066f803f96eec892fee773c431fba7cc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:24:48 -0800 Subject: [PATCH 289/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e04c573c6..fbc6d53af 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -706,9 +706,10 @@ def LlamaModel_fast_forward( pass # Ignore attention_mask + print(attention_mask, attention_mask.shape, attention_mask.dtype) if attention_mask is None: padding_mask = None - elif self.training: + elif attention_mask is None and self.training: # attention_mask = None padding_mask = None else: From e548b1517970a26ddb743eb3a2dbcac07da06684 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:29:10 -0800 Subject: [PATCH 290/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fbc6d53af..653ebb351 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -706,7 +706,6 @@ def LlamaModel_fast_forward( pass # Ignore attention_mask - print(attention_mask, attention_mask.shape, attention_mask.dtype) if attention_mask is None: padding_mask = None elif attention_mask is None and self.training: From 296b3b3196010f14cd872650d455d0d1929e56a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:33:37 -0800 Subject: [PATCH 291/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 653ebb351..088450b9e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask.shape, Q.shape, K.shape, V.shape, attention_mask.dtype) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 8de588b4df1091d3da0d635b01e1417b24c4eda7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:06:33 -0800 Subject: [PATCH 292/473] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 088450b9e..0b567b023 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif attention_mask is None and self.training: - # attention_mask = None + elif self.training: + attention_mask = None padding_mask = None else: # if 0 in attention_mask: From f87909a12c01c59b9b5584a023f88e69530406f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:16:44 -0800 Subject: [PATCH 293/473] Update pyproject.toml --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d89ea2c4d..5bdf3c4dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,9 +187,9 @@ cu124onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", From 270444089c55bbc200de6fa045c9690dacb1fdc8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 22:34:50 -0800 Subject: [PATCH 294/473] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023..8436ab18e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,9 +707,9 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: - attention_mask = None - padding_mask = None + # elif self.training: + # attention_mask = None + # padding_mask = 
None else: # if 0 in attention_mask: # padding_mask = attention_mask From 42e196752b2789d185914928f5fa619fc148c511 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 22:56:47 -0800 Subject: [PATCH 295/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8436ab18e..af144f01a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1161,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + print("***", logits.dtype, logits.shape) return CausalLMOutputWithPast( loss=loss, logits=logits, From 36bf805fa331a35c811e3f82a2d9348ad3732843 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:00:51 -0800 Subject: [PATCH 296/473] Update llama.py --- unsloth/models/llama.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index af144f01a..1f002b559 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1162,6 +1162,13 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output print("***", logits.dtype, logits.shape) + print(CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )) return CausalLMOutputWithPast( loss=loss, logits=logits, From a3af8e3718cc3e4208828d02757224feff42921d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:03:40 -0800 Subject: [PATCH 297/473] Update llama.py --- unsloth/models/llama.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1f002b559..af144f01a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1162,13 +1162,6 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output print("***", logits.dtype, logits.shape) - print(CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - )) return CausalLMOutputWithPast( loss=loss, logits=logits, From 9d10d2f41b2cf825a934c35021ae30d6789bb372 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:10:44 -0800 Subject: [PATCH 298/473] Update llama.py --- unsloth/models/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index af144f01a..0b567b023 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,9 +707,9 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - # elif self.training: - # attention_mask = None - # padding_mask = None + elif self.training: + attention_mask = None + padding_mask = None else: # if 0 in attention_mask: # padding_mask = attention_mask @@ -1161,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print("***", logits.dtype, logits.shape) + return CausalLMOutputWithPast( loss=loss, logits=logits, From b30a81f3085743228b25e42b2bae0caf1b3a46df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:26:58 -0800 Subject: [PATCH 299/473] Update llama.py --- 
unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023..2d5e43ba6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1189,7 +1189,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=0, **kwargs, ): - return self.base_model( + a = self.base_model( input_ids=input_ids, causal_mask=causal_mask, attention_mask=attention_mask, @@ -1201,6 +1201,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=num_logits_to_keep, **kwargs, ) + print(a) pass From b7e855945e7413bd17d61014deb5c53c718d40c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:29:34 -0800 Subject: [PATCH 300/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 2d5e43ba6..0b567b023 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1189,7 +1189,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=0, **kwargs, ): - a = self.base_model( + return self.base_model( input_ids=input_ids, causal_mask=causal_mask, attention_mask=attention_mask, @@ -1201,7 +1201,6 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=num_logits_to_keep, **kwargs, ) - print(a) pass From 4b201d98c6cc5dfec3e249dde69c9fb7f9344c0b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:42:55 -0800 Subject: [PATCH 301/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index aaa5b7214..0df37e508 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -124,6 +124,7 @@ def grpo_trainer__prepare_inputs(function_name, function): if "with torch.inference_mode()" not in function: return function + # Add mixed precision training function = function.replace( "with torch.inference_mode():", @@ -132,6 +133,12 @@ def grpo_trainer__prepare_inputs(function_name, function): "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():", ) + + # Disable attaching a float32 conversion hook which upcasts logits to FP32 + function = function.replace( + "self.accelerator.unwrap_model(self.model)", + "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)", + ) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) @@ -148,15 +155,26 @@ def grpo_trainer__move_model_to_vllm(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) -# Edit _get_per_token_logps +# Edit _get_per_token_logps to handle mixed precision def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - # Set attention_mask to boolean - # function = function.replace( - # "attention_mask=attention_mask", - # "attention_mask=attention_mask.to(torch.bool)" - # ) + # Edit model to autocast it + # .*? matches first match. .+? matches final match. 
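# Illustrative sketch of the autocast wrapper this regex splice inserts around the
# model call (not part of the diff; `model` and `input_ids` are hypothetical, and
# torch.is_autocast_enabled with a device argument needs a recent PyTorch).
import os, torch
from contextlib import nullcontext

dtype = (torch.float16
         if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
         else torch.bfloat16)
ctx = (torch.amp.autocast(device_type="cuda", dtype=dtype)
       if not torch.is_autocast_enabled("cuda") else nullcontext())
# with ctx:
#     logits = model(input_ids=input_ids).logits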
+ original = re.findall( + f"logits = model\(.*?\)", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(original) != 0: + original = original[0] + replacer = \ + " "*4 + "with torch.amp.autocast(device_type = 'cuda', "\ + "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ + " "*8 + original + function = function.replace(original, replacer) + pass return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From dc723bc70eb78c914a6f86d6a69e94328c3ac179 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:44:00 -0800 Subject: [PATCH 302/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0df37e508..a56a7840c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -169,7 +169,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if len(original) != 0: original = original[0] replacer = \ - " "*4 + "with torch.amp.autocast(device_type = 'cuda', "\ + "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ " "*8 + original From 0309949b63080b8b5a7834c217bce9e0c950cad6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:48:38 -0800 Subject: [PATCH 303/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a56a7840c..e8fb1ffc0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,11 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: original = original[0] + spaces = function.find(original) replacer = \ "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ - " "*8 + original + " "*(spaces + 4) + original function = function.replace(original, replacer) pass return function From c409574568715e7552572bff61411ec2d6acd7e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:52:07 -0800 Subject: [PATCH 304/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e8fb1ffc0..6abab318a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -85,7 +85,7 @@ def sft_trainer_prepare_dataset(function_name, function): # .*? matches first match. .+? matches final match. replacer = re.findall( - f"def {function_name}\(.*?\).*?\:\n", + r"def {function_name}\(.*?\).*?\:\n", function, flags = re.MULTILINE | re.DOTALL, ) @@ -104,7 +104,7 @@ def sft_trainer_compute_loss(function_name, function): # .*? matches first match. .+? matches final match. 
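# Illustrative sketch of the non-greedy matching this comment refers to (not part of
# the diff). With re.DOTALL, ".*?" stops at the first ")" after "model(", while ".*"
# keeps going to the last ")" anywhere in the string.
import re

src = "logits = model(input_ids=input_ids, attention_mask=mask)\nprobs = F.softmax(logits, dim=-1)"
lazy   = re.findall(r"logits = model\(.*?\)", src, flags=re.MULTILINE | re.DOTALL)
greedy = re.findall(r"logits = model\(.*\)",  src, flags=re.MULTILINE | re.DOTALL)
print(lazy)    # ['logits = model(input_ids=input_ids, attention_mask=mask)']
print(greedy)  # one match that runs on to the final ")" on the second line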
replacer = re.findall( - f"\.compute_loss\(.*?\)", + r"\.compute_loss\(.*?\)", function, flags = re.MULTILINE | re.DOTALL, ) @@ -162,13 +162,13 @@ def grpo_trainer__get_per_token_logps(function_name, function): # Edit model to autocast it # .*? matches first match. .+? matches final match. original = re.findall( - f"logits = model\(.*?\)", + r"\n([ ]{4,})(logits = model\(.*?\))", function, flags = re.MULTILINE | re.DOTALL, ) if len(original) != 0: - original = original[0] - spaces = function.find(original) + spaces, original = original[0] + spaces = len(spaces) replacer = \ "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ From 8e5b09adb05e9306a11e81a35ef1e07adc1d80ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:57:45 -0800 Subject: [PATCH 305/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023..d6916814a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,7 +707,7 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: + elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: @@ -723,6 +723,7 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) + attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From 6652f1df661e973cc122d0260fd266511942a3f2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:01:51 -0800 Subject: [PATCH 306/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6abab318a..048db868a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,12 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: spaces, original = original[0] - spaces = len(spaces) + spaces = len(spaces) + 4 replacer = \ - "with torch.amp.autocast(device_type = 'cuda', "\ - "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ - " "*(spaces + 4) + original + "if not hasattr(self, '_autocast_dtype'):\n" + \ + " "*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*spaces + original function = function.replace(original, replacer) pass return function From 9215bbefb5a0ec03f08870c93bf9b2b745c8a50b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:04:11 -0800 Subject: [PATCH 307/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 048db868a..81ca2debc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -172,7 +172,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): replacer = \ "if not hasattr(self, '_autocast_dtype'):\n" + \ " 
"*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*spaces + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ " "*spaces + original function = function.replace(original, replacer) pass From 4bff998081e3622bb60080dd51d631cd8e37a797 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:06:36 -0800 Subject: [PATCH 308/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 81ca2debc..3eb16bb1f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,12 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: spaces, original = original[0] - spaces = len(spaces) + 4 + spaces = len(spaces) replacer = \ "if not hasattr(self, '_autocast_dtype'):\n" + \ - " "*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - " "*spaces + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ - " "*spaces + original + " "*(spaces + 4) + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ + " "*(spaces + 0) + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*(spaces + 4) + original function = function.replace(original, replacer) pass return function From c859030d0f641502b63a5a6941a03774e5525580 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:25:56 -0800 Subject: [PATCH 309/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3eb16bb1f..968f2b19f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -18,6 +18,7 @@ ] import re +import inspect from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) @@ -99,20 +100,21 @@ def sft_trainer_prepare_dataset(function_name, function): # Ignore mean_token_accuracy since it needs logits +# We override it directly with our version +def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): + (loss, outputs) = super().compute_loss( + model, + inputs, + return_outputs = return_outputs, + num_items_in_batch = num_items_in_batch, + ) + return (loss, outputs) if return_outputs else loss +pass + def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - # .*? matches first match. .+? matches final match. 
- replacer = re.findall( - r"\.compute_loss\(.*?\)", - function, - flags = re.MULTILINE | re.DOTALL, - ) - if len(replacer) != 0: - replacer = replacer[0] - returner = "\n" + " "*8 + "return (loss, outputs) if return_outputs else loss" - function = function.replace(replacer, replacer + returner) - pass + function = inspect.getsource(_sft_trainer_compute_loss) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 2daa8e3e3cf5715f13d31dc0372fb0cb094cf756 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:34:13 -0800 Subject: [PATCH 310/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 968f2b19f..5da57c44b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -115,6 +115,7 @@ def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function function = inspect.getsource(_sft_trainer_compute_loss) + function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 527a0c4fc8f18b22926bb29b4919109a7113b4da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:40:42 -0800 Subject: [PATCH 311/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5da57c44b..4d7a4dbe0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -116,6 +116,8 @@ def sft_trainer_compute_loss(function_name, function): function = inspect.getsource(_sft_trainer_compute_loss) function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") + function = function.split("\n") + function = "\n".join(" "*4+x for x in function) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 087a5dc2f02a6fdcbc76d3e33e3a4c7104874f75 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:40:53 -0800 Subject: [PATCH 312/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d7a4dbe0..aeb5f3e0d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -108,6 +108,7 @@ def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_i return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) + print(loss, outputs) return (loss, outputs) if return_outputs else loss pass From 73210b3b8e82131b23ea47eb43e53d69c7de571f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:44:21 -0800 Subject: [PATCH 313/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index aeb5f3e0d..4d7a4dbe0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -108,7 +108,6 @@ def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_i return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) - print(loss, outputs) return (loss, outputs) if return_outputs else loss pass From 2635f2af96ea1ea592eb7008763dba4b7833dd2c Mon Sep 17 00:00:00 2001 
From: Daniel Han Date: Thu, 13 Feb 2025 14:42:38 -0800 Subject: [PATCH 314/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6916814a..ec6706e51 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,7 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif attention_mask is not None and self.training: + elif self.training: + # elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: From 69ab838499d4c53413d214732690d3f8fad1724b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 14:47:54 -0800 Subject: [PATCH 315/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 6aa7f94cf..656096b70 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.4" +__version__ = "2025.2.5" __all__ = [ "SUPPORTS_BFLOAT16", From acf98dccdcfb3a4c329230517603dea9bb214250 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:11:51 -0800 Subject: [PATCH 316/473] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ec6706e51..511ae5c68 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -724,7 +724,8 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) - attention_mask = attention_mask.to(torch.bool) + if attention_mask is not None: + attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From 139911095fca3316fd24cbbedc7236e279c48413 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:23:51 -0800 Subject: [PATCH 317/473] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f5d00eab2..8d0eadb96 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2025.2.8" +__version__ = "2025.2.9" __all__ = [ "SUPPORTS_BFLOAT16", From 881105b2c828c0580b9d60b2b2432b379c4733ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:27:02 -0800 Subject: [PATCH 318/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d7a4dbe0..0a6ea5dff 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -101,23 +101,20 @@ def sft_trainer_prepare_dataset(function_name, function): # Ignore mean_token_accuracy since it needs logits # We override it directly with our version -def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): - (loss, outputs) = super().compute_loss( - model, - inputs, - return_outputs = return_outputs, - num_items_in_batch = num_items_in_batch, - ) - return (loss, outputs) if return_outputs else loss -pass - def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - function = inspect.getsource(_sft_trainer_compute_loss) - function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") - function = function.split("\n") - function = "\n".join(" "*4+x for x in function) + def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): + (loss, outputs) = super().compute_loss( + model, + inputs, + return_outputs = return_outputs, + num_items_in_batch = num_items_in_batch, + ) + return (loss, outputs) if return_outputs else loss + pass + + function = inspect.getsource(compute_loss) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From cfdd3f150f011132c72e713a3dd8c374229da1f3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:27:23 -0800 Subject: [PATCH 319/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fc094b083..b8d191dcf 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -402,7 +402,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer From 95b7df53e874ce8ea55fdcfa6c2568182e30d16d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:35:09 -0800 Subject: [PATCH 320/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b8d191dcf..fadae874d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -404,6 +404,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): imports, overwrite = True, ) + print("###") # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From 17bfcf9ebb94672746c9d17b3df90a6c854900b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:37:05 -0800 Subject: [PATCH 321/473] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fadae874d..048ec7bb0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -566,8 +566,8 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = None, FastLanguageModel = None): - return - # if FastLanguageModel is not None: 
PatchRL(FastLanguageModel) - # patch_trl_rl_trainers() - # if algorithm is not None: PatchRLStatistics(algorithm) + if FastLanguageModel is not None: PatchRL(FastLanguageModel) + patch_trl_rl_trainers() + if type(algorithm) is str and algorithm.islower(): + PatchRLStatistics(algorithm) pass From 61c219d4fc610c9a2706c62d88956b5290462019 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:38:21 -0800 Subject: [PATCH 322/473] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 048ec7bb0..9f5fe99c9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -404,7 +404,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): imports, overwrite = True, ) - print("###") # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From 9794dc230878e74f724649310ea1eae80b360ab6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:47:27 -0800 Subject: [PATCH 323/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9f5fe99c9..3d601b0af 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -402,7 +402,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From 3687a6f7b9192faa3c2ef79fbd1fa2b8caffd1a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:00:14 -0800 Subject: [PATCH 324/473] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 511ae5c68..817b014ac 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1052,6 +1052,7 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None + print(1055, input_ids) outputs = self.model( input_ids=input_ids, causal_mask=causal_mask, @@ -1064,6 +1065,7 @@ def _CausalLM_fast_forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) + print(1068) pass hidden_states = outputs[0] From c495bfad6922a45171a39179427d67d206b9e7db Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:05:04 -0800 Subject: [PATCH 325/473] Update llama.py --- unsloth/models/llama.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 817b014ac..188c12ba9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1068,6 +1068,7 @@ def _CausalLM_fast_forward( print(1068) pass hidden_states = outputs[0] + print(1071) bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight @@ -1084,6 +1085,8 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! 
if bsz*q_len <= 1024: RETURN_LOGITS = True + + print(1089) if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: @@ -1095,6 +1098,8 @@ def _CausalLM_fast_forward( num_items_in_batch = n_items, logit_softcapping = logit_softcapping, ) + + print(1102, loss) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -1108,6 +1113,7 @@ def _CausalLM_fast_forward( ) return output pass + print(1116, hidden_states.dtype, hidden_states.shape) logits = self.lm_head(hidden_states.to(dtype)) pass @@ -1117,6 +1123,7 @@ def _CausalLM_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass + print(1126) loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) @@ -1142,6 +1149,7 @@ def _CausalLM_fast_forward( logit_scaling = logit_scaling, n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) + print(1152, loss) else: if logit_scaling != 0: if logits.requires_grad: @@ -1166,7 +1174,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + print(1177, loss, logits.shape, logits.dtype) return CausalLMOutputWithPast( loss=loss, logits=logits, From f9055a767e1ea34b333363873b6533135a86fd49 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:11:33 -0800 Subject: [PATCH 326/473] Update llama.py --- unsloth/models/llama.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 188c12ba9..511ae5c68 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1052,7 +1052,6 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None - print(1055, input_ids) outputs = self.model( input_ids=input_ids, causal_mask=causal_mask, @@ -1065,10 +1064,8 @@ def _CausalLM_fast_forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - print(1068) pass hidden_states = outputs[0] - print(1071) bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight @@ -1085,8 +1082,6 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! 
if bsz*q_len <= 1024: RETURN_LOGITS = True - - print(1089) if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: @@ -1098,8 +1093,6 @@ def _CausalLM_fast_forward( num_items_in_batch = n_items, logit_softcapping = logit_softcapping, ) - - print(1102, loss) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -1113,7 +1106,6 @@ def _CausalLM_fast_forward( ) return output pass - print(1116, hidden_states.dtype, hidden_states.shape) logits = self.lm_head(hidden_states.to(dtype)) pass @@ -1123,7 +1115,6 @@ def _CausalLM_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass - print(1126) loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) @@ -1149,7 +1140,6 @@ def _CausalLM_fast_forward( logit_scaling = logit_scaling, n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) - print(1152, loss) else: if logit_scaling != 0: if logits.requires_grad: @@ -1174,7 +1164,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print(1177, loss, logits.shape, logits.dtype) + return CausalLMOutputWithPast( loss=loss, logits=logits, From 945e3f95e14a90f4d5b75b60d85ab8b7ced22e33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:12:06 -0800 Subject: [PATCH 327/473] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 511ae5c68..04d2ee039 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: - # elif attention_mask is not None and self.training: + # elif self.training: + elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: From 3d9fe12a2310771c4f6a858b82a90069f8f1061e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:15:29 -0800 Subject: [PATCH 328/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0a6ea5dff..82fd3f8d3 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -105,13 +105,13 @@ def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): - (loss, outputs) = super().compute_loss( + outputs = super().compute_loss( model, inputs, return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) - return (loss, outputs) if return_outputs else loss + return outputs pass function = inspect.getsource(compute_loss) From ed907850ad1bccf330488dc7d751189418046c7d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:18:42 -0800 Subject: [PATCH 329/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 04d2ee039..841dcd7c4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print("##") # Needs (batch_size, n_heads, 
seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 640bc8878e7820a3d8f6eb4dee4198dec4a49957 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:22:39 -0800 Subject: [PATCH 330/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 841dcd7c4..d6f6ae6f0 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -709,7 +709,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None # elif self.training: - elif attention_mask is not None and self.training: + elif attention_mask is not None: attention_mask = None padding_mask = None else: From bb3bb2dc8c059fc6e3f303b9fca6cfceb7dfef8a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:25:12 -0800 Subject: [PATCH 331/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6f6ae6f0..811e6ccd1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -709,7 +709,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None # elif self.training: - elif attention_mask is not None: + elif attention_mask is None: attention_mask = None padding_mask = None else: From 9065938acb1d8614c830194bb5117fb87f13899a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 19:11:38 -0800 Subject: [PATCH 332/473] Update llama.py --- unsloth/models/llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 811e6ccd1..1eae97ff1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print("##") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
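# Illustrative sketch of the enable_gqa path taken just below (not part of the diff;
# shapes are made-up, and the flag needs a PyTorch build whose
# scaled_dot_product_attention accepts enable_gqa, roughly 2.5 or newer).
import torch
from torch.nn.functional import scaled_dot_product_attention

bsz, n_heads, n_kv_heads, seq, hd = 2, 32, 8, 16, 64
Q = torch.randn(bsz, n_heads,    seq, hd)
K = torch.randn(bsz, n_kv_heads, seq, hd)
V = torch.randn(bsz, n_kv_heads, seq, hd)
# SDPA broadcasts the 8 KV heads across the 32 query heads internally,
# so the manual expand / reshape of K and V is no longer needed.
A = scaled_dot_product_attention(Q, K, V, is_causal=True, enable_gqa=True)
print(A.shape)  # torch.Size([2, 32, 16, 64])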
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) @@ -708,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - # elif self.training: - elif attention_mask is None: + elif self.training: + # elif attention_mask is None: attention_mask = None padding_mask = None else: From 48c5e0d121ec1e651e103e98b3d63b0300447e9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:30:15 -0800 Subject: [PATCH 333/473] GRPO optimized --- unsloth/models/rl.py | 55 ++++++++++++- unsloth/models/rl_replacements.py | 127 ++++++++++++++++++++++++++---- 2 files changed, 165 insertions(+), 17 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3d601b0af..a216f4f38 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -26,8 +26,17 @@ from .rl_replacements import ( RL_EXTRA_ARGS, RL_FUNCTIONS, + RL_PRE_ITEMS, ) +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -74,6 +83,23 @@ def generate_with_clone(*args, **kwargs): pass +# https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def _selective_log_softmax(logits, index): + logits = logits.to(torch.float32) + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + logsumexp_values = torch.logsumexp(logits, dim = -1) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + return per_token_logps +pass + +def selective_log_softmax(logits, index): + return _selective_log_softmax(logits, index) +pass + + RLTrainer_replacement = ''' import os from typing import * @@ -81,6 +107,17 @@ def generate_with_clone(*args, **kwargs): from packaging.version import Version import torch from contextlib import nullcontext +from torch.nn import functional as F +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} + +{selective_log_softmax_code} +{RL_pre} @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): @@ -377,6 +414,19 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + # Get all pre-modules + if RLTrainer_name in RL_PRE_ITEMS: + RL_pre = "\n".join(RL_PRE_ITEMS) + else: + RL_pre = "" + pass + + # Selective log softmax + selective_log_softmax_code = \ + inspect.getsource(_selective_log_softmax) + "\n" + \ + inspect.getsource(selective_log_softmax) + "\n" + + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, __RLTrainer_doc__ = __RLTrainer_doc__, @@ -394,6 +444,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, + RL_pre = RL_pre, + + selective_log_softmax_code = selective_log_softmax_code, ) # Create new function @@ -402,7 +455,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): 
RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 82fd3f8d3..39db05355 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -15,6 +15,7 @@ __all__ = [ "RL_EXTRA_ARGS", "RL_FUNCTIONS", + "RL_PRE_ITEMS", ] import re @@ -22,7 +23,15 @@ from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} # Check untrained tokens def sft_trainer_fix_untraiend_tokens(call_args, extra_args): @@ -161,23 +170,109 @@ def grpo_trainer__move_model_to_vllm(function_name, function): def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - # Edit model to autocast it - # .*? matches first match. .+? matches final match. - original = re.findall( - r"\n([ ]{4,})(logits = model\(.*?\))", - function, - flags = re.MULTILINE | re.DOTALL, - ) - if len(original) != 0: - spaces, original = original[0] - spaces = len(spaces) - replacer = \ - "if not hasattr(self, '_autocast_dtype'):\n" + \ - " "*(spaces + 4) + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - " "*(spaces + 0) + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ - " "*(spaces + 4) + original - function = function.replace(original, replacer) + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + if not hasattr(self, '_autocast_dtype'): + self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 + with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
+ # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + pass pass + + function = inspect.getsource(_get_per_token_logps) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) + + +# Custom compiled GRPO loss - creates 3 Triton kernels +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): + old_logits = old_logits.to(torch.float32) + new_logits = new_logits.to(torch.float32) + input_ids = input_ids.unsqueeze(-1) + + # x_i - logsumexp(x_i) + old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) + new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) + old = old_x - torch.logsumexp(old_logits, dim = -1) + new = new_x - torch.logsumexp(new_logits, dim = -1) + + kl_i = torch.exp(old - new) - (old - new) - 1.0 + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(loss_i - beta * kl_i) + + mask = mask.to(torch.float32) + n_mask = mask.sum(1) + loss_per_reward = (loss_i * mask).sum(1) / n_mask + loss = loss_per_reward.mean() + + # Get metrics as well which are folded + with torch.inference_mode(): + completion_length = n_mask.mean() + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask + mean_kl = mean_kl_per_reward.mean() + pass + return loss, completion_length, mean_kl +pass +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): + loss, completion_length, mean_kl = _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta) + return loss, completion_length.item(), mean_kl.item() +pass +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((_grpo_compute_loss))) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) + + +# Edit _get_per_token_logps to handle mixed precision +def grpo_trainer_compute_loss(function_name, function): + if function_name != "compute_loss": return function + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + # Compute the per-token log probabilities for the model + + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + attention_mask = None + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + + # Compute the KL divergence between the model and the reference model + ref_per_token_logps = inputs["ref_per_token_logps"] + # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # x - x.detach() allows for preserving gradients from x + advantages = inputs["advantages"] + # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + # per_token_loss = -(per_token_loss - self.beta * per_token_kl) + # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + + loss, completion_length, mean_kl = grpo_compute_loss( + 
ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, + ) + # Log the metrics + # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + + # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + self._metrics["kl"].append(mean_kl) + return loss + pass + + function = inspect.getsource(compute_loss) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) From 3a1fb635b4dcd977d282a2c9f84f98f0bac2af59 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:31:27 -0800 Subject: [PATCH 334/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a216f4f38..ecd394cea 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -21,6 +21,7 @@ import inspect import os import re +import torch from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics from .rl_replacements import ( From 19014b0f7e73fae525b3dba08374e5534525867d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:32:24 -0800 Subject: [PATCH 335/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 39db05355..ed802e487 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -19,6 +19,7 @@ ] import re +import torch import inspect from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) From 0c17e794f35c49f803a27f9ed2dac5126942820b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:33:41 -0800 Subject: [PATCH 336/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ed802e487..9b9a113f2 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -212,14 +212,14 @@ def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) - n_mask = mask.sum(1) - loss_per_reward = (loss_i * mask).sum(1) / n_mask + n_mask_per_reward = mask.sum(1) + loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward loss = loss_per_reward.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask.mean() - mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask + completion_length = n_mask_per_reward.mean() + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward mean_kl = mean_kl_per_reward.mean() pass return loss, completion_length, mean_kl From aee44e219f31cb201e28221136b50d5ae21f5ce1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:35:03 -0800 Subject: [PATCH 337/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ecd394cea..8dcd855d0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -109,13 +109,13 @@ def selective_log_softmax(logits, index): import torch from contextlib import nullcontext from torch.nn import functional as F -torch_compile_options = 
{ +torch_compile_options = {{ "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, -} +}} {selective_log_softmax_code} {RL_pre} From 953d957c694a8954050e309a8687c42023c290c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:38:03 -0800 Subject: [PATCH 338/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8dcd855d0..fb1446037 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -417,7 +417,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if RLTrainer_name in RL_PRE_ITEMS: - RL_pre = "\n".join(RL_PRE_ITEMS) + RL_pre = "\n".join(RL_PRE_ITEMS[RLTrainer_name]) else: RL_pre = "" pass From 2a2b9f7c7cd4ce8b4326fe05e73e768ff177eae5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:42:05 -0800 Subject: [PATCH 339/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fb1446037..1ac511e83 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -416,8 +416,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ # Get all pre-modules - if RLTrainer_name in RL_PRE_ITEMS: - RL_pre = "\n".join(RL_PRE_ITEMS[RLTrainer_name]) + if trainer_file in RL_PRE_ITEMS: + RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From fcb0f4aad69f70a009217953e4333c478c599cec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:44:03 -0800 Subject: [PATCH 340/473] Update rl.py --- unsloth/models/rl.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1ac511e83..128725a0a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -86,7 +86,7 @@ def generate_with_clone(*args, **kwargs): # https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def _selective_log_softmax(logits, index): +def selective_log_softmax(logits, index): logits = logits.to(torch.float32) selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption @@ -96,10 +96,6 @@ def _selective_log_softmax(logits, index): return per_token_logps pass -def selective_log_softmax(logits, index): - return _selective_log_softmax(logits, index) -pass - RLTrainer_replacement = ''' import os @@ -423,10 +419,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Selective log softmax - selective_log_softmax_code = \ - inspect.getsource(_selective_log_softmax) + "\n" + \ - inspect.getsource(selective_log_softmax) + "\n" - + selective_log_softmax_code = inspect.getsource(selective_log_softmax) + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, From eabc36527590a07449aa4da25196b8a876783752 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:45:48 -0800 Subject: [PATCH 341/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9b9a113f2..36022f1e3 100644 --- a/unsloth/models/rl_replacements.py +++ 
b/unsloth/models/rl_replacements.py @@ -195,7 +195,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +# @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) From 74083182a2092af9adc7fc000e4ae44894115db4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:49:41 -0800 Subject: [PATCH 342/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 36022f1e3..30b304563 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -195,7 +195,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels -# @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) @@ -247,7 +247,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -259,7 +259,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - + input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, ) From f35eae3a90d4ba57865bb9cdb6c8000da5408603 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:53:06 -0800 Subject: [PATCH 343/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 30b304563..c4a52987a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -196,7 +196,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) @@ -224,11 +224,6 @@ def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, 
beta): pass return loss, completion_length, mean_kl pass -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): - loss, completion_length, mean_kl = _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta) - return loss, completion_length.item(), mean_kl.item() -pass -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((_grpo_compute_loss))) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) @@ -247,7 +242,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -261,15 +256,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() - self._metrics["completion_length"].append(completion_length) + self._metrics["completion_length"].append(completion_length.item()) # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) - self._metrics["kl"].append(mean_kl) + self._metrics["kl"].append(mean_kl.item()) return loss pass From 2b89daea278ac4bd3cf148c291449fd726ffd131 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 14:56:37 -0800 Subject: [PATCH 344/473] Selective Log softmax --- unsloth/models/rl.py | 17 +++----------- unsloth/models/rl_replacements.py | 38 ++++--------------------------- 2 files changed, 7 insertions(+), 48 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 128725a0a..58b6d8271 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -24,11 +24,13 @@ import torch from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics +from unsloth_zoo.rl_replacements import RL_REPLACEMENTS from .rl_replacements import ( RL_EXTRA_ARGS, RL_FUNCTIONS, RL_PRE_ITEMS, ) +selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] torch_compile_options = { "epilogue_fusion" : True, @@ -84,19 +86,6 @@ def generate_with_clone(*args, **kwargs): pass -# https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def selective_log_softmax(logits, index): - logits = logits.to(torch.float32) - selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) - # loop to reduce peak mem consumption - # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) - logsumexp_values = torch.logsumexp(logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) - return per_token_logps -pass - - RLTrainer_replacement = ''' import os from 
typing import * @@ -420,7 +409,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) - + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c4a52987a..d01f6cd45 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -22,6 +22,7 @@ import torch import inspect from collections import defaultdict +from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) RL_PRE_ITEMS = defaultdict(list) @@ -193,45 +194,14 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) - -# Custom compiled GRPO loss - creates 3 Triton kernels -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): - old_logits = old_logits.to(torch.float32) - new_logits = new_logits.to(torch.float32) - input_ids = input_ids.unsqueeze(-1) - - # x_i - logsumexp(x_i) - old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) - new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) - old = old_x - torch.logsumexp(old_logits, dim = -1) - new = new_x - torch.logsumexp(new_logits, dim = -1) - - kl_i = torch.exp(old - new) - (old - new) - 1.0 - loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(loss_i - beta * kl_i) - - mask = mask.to(torch.float32) - n_mask_per_reward = mask.sum(1) - loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward.mean() - - # Get metrics as well which are folded - with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() - mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() - pass - return loss, completion_length, mean_kl -pass -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) - +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") # Compute the per-token log probabilities for the model From 45c8431715572d5c18c513a4ab7d8de9d9a5fc1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:32:02 -0800 Subject: [PATCH 345/473] Fix GRPO bsz --- unsloth/models/rl.py | 16 +++++++++++++++- unsloth/models/rl_replacements.py | 24 +++++++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 58b6d8271..eba1e46a2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -29,6 +29,7 @@ RL_EXTRA_ARGS, RL_FUNCTIONS, RL_PRE_ITEMS, + RL_CONFIG_CHANGES, ) selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] @@ -165,8 +166,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if 
RLTrainer.__name__.startswith("Unsloth"): return if RLConfig .__name__.startswith("Unsloth"): return + # Get old source + old_RLTrainer_source = inspect.getsource(RLTrainer) + old_RLConfig_source = inspect.getsource(RLConfig) + all_imports = dir(trainer) - imports = [x for x in all_imports if not x.startswith("_")] + # imports = [x for x in all_imports if not x.startswith("_")] + # Fix _deprecate_arguments not getting imported + imports = all_imports # Get default arguments EMPTY = inspect.Parameter.empty @@ -381,6 +388,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass + # Edit config with anything extra + if trainer_file in RL_CONFIG_CHANGES: + process_extra_args = RL_CONFIG_CHANGES[trainer_file] + for process_extra_arg in process_extra_args: + extra_args += process_extra_arg(old_RLTrainer_source, old_RLConfig_source) + pass + # Edit report_to and default it to nothing if max_steps is like 60 # Create RLConfig args diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index d01f6cd45..fefba2444 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -16,6 +16,7 @@ "RL_EXTRA_ARGS", "RL_FUNCTIONS", "RL_PRE_ITEMS", + "RL_CONFIG_CHANGES", ] import re @@ -23,9 +24,10 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS -RL_EXTRA_ARGS = defaultdict(list) -RL_FUNCTIONS = defaultdict(list) -RL_PRE_ITEMS = defaultdict(list) +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +RL_CONFIG_CHANGES = defaultdict(list) torch_compile_options = { "epilogue_fusion" : True, @@ -242,3 +244,19 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) + +# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 +# TRL warns if batch size is not a multiple of num_generations -> fix this. 
+def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): + if "multiple of num_generations" not in RLTrainer_source: return "" + if "num_generations" not in RLConfig_source: return "" + + check_batch_size = \ + "div = per_device_train_batch_size // num_generations\n"\ + "if div * num_generations != per_device_train_batch_size:\n"\ + " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n'\\"\ + " 'We will change the batch size of ' + per_device_train_batch_size + ' to the `num_generations` of ' + num_generations')\n"\ + " per_device_train_batch_size = num_generations\n" + return check_batch_size +pass +RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size) From 644cedfa339be1c29b5226f30a67995b7a36877f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:56:05 -0800 Subject: [PATCH 346/473] Update rl.py --- unsloth/models/rl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index eba1e46a2..2875ff64a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -171,9 +171,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): old_RLConfig_source = inspect.getsource(RLConfig) all_imports = dir(trainer) - # imports = [x for x in all_imports if not x.startswith("_")] - # Fix _deprecate_arguments not getting imported - imports = all_imports + # Fix _deprecate_arguments not getting imported so stop __ but not _ + imports = [x for x in all_imports if not x.startswith("__")] # Get default arguments EMPTY = inspect.Parameter.empty From 4b765d77590054598eaffbe2b1cce9416c786ee8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:58:13 -0800 Subject: [PATCH 347/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index fefba2444..c7fdb4cbd 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,7 +248,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "multiple of num_generations" not in RLTrainer_source: return "" + if "divisible by the number of generations" not in RLTrainer_source: return "" if "num_generations" not in RLConfig_source: return "" check_batch_size = \ From 0a7c56d7bdd4aa39d86abf20722ab7b92c182c8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:01:29 -0800 Subject: [PATCH 348/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c7fdb4cbd..682a35ed1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,8 +248,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. 
def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "divisible by the number of generations" not in RLTrainer_source: return "" - if "num_generations" not in RLConfig_source: return "" + if "divisible by the number of generations" not in RLTrainer_source: + print(RLTrainer_source) + return "" + if "num_generations" not in RLConfig_source: + print(RLConfig_source) + return "" check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ From 1b43e1de8dbccd6c580b47a4475a57eedcef1530 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:03:13 -0800 Subject: [PATCH 349/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 682a35ed1..2925bd5b7 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,18 +248,14 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "divisible by the number of generations" not in RLTrainer_source: - print(RLTrainer_source) - return "" - if "num_generations" not in RLConfig_source: - print(RLConfig_source) - return "" + if "divisible by the number of generations" not in RLTrainer_source: return "" + if "num_generations" not in RLConfig_source: return "" check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ - " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n'\\"\ - " 'We will change the batch size of ' + per_device_train_batch_size + ' to the `num_generations` of ' + num_generations')\n"\ + " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)')\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size pass From d588665d98934d502dfc852237e9f2ddda086892 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:08:49 -0800 Subject: [PATCH 350/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2925bd5b7..63fe24359 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -255,7 +255,7 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ - "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)')\n"\ + "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size pass From 54bd82743363ef79fa081e35c5fbcacd13379de5 Mon Sep 17 00:00:00 2001 From: Daniel 
Han Date: Sat, 15 Feb 2025 01:13:41 -0800 Subject: [PATCH 351/473] Fix TRL --- pyproject.toml | 34 +++++++++++++++++----------------- unsloth/models/_utils.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a6e31dca..59a7c4473 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.2", + "unsloth_zoo>=2025.2.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -50,7 +50,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -176,26 +176,26 @@ cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and 
platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu126onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu118 = [ "unsloth[huggingface]", @@ -344,7 +344,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.2", + "unsloth_zoo>=2025.2.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -362,7 +362,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1", "xformers", "bitsandbytes>=0.46.1", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 8d0eadb96..df925d746 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2025.2.9" +__version__ = "2025.2.10" __all__ = [ "SUPPORTS_BFLOAT16", From fa560ce4e7d381cd346b3221004e910c35a41ebe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:08:33 -0800 Subject: [PATCH 352/473] Metrics GRPO --- unsloth/models/_utils.py | 2 +- unsloth/models/rl.py | 13 ++++++++++++- unsloth/models/rl_replacements.py | 26 ++++++++++++++++++++++---- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index df925d746..2a5b71d39 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.10" +__version__ = "2025.2.11" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2875ff64a..7b363d8fc 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -30,6 +30,7 @@ RL_FUNCTIONS, RL_PRE_ITEMS, RL_CONFIG_CHANGES, + RL_METRICS_CHANGES, ) selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] @@ -310,10 +311,20 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_post += neftune_check pass + # Edit optional metrics + other_metrics_processor = "" + if trainer_file in RL_METRICS_CHANGES: + process_extra_args = RL_METRICS_CHANGES[trainer_file] + for process_extra_arg in process_extra_args: + other_metrics_processor += process_extra_arg(call_args, extra_args) + pass + # Add statistics as well! extra_args += \ + "other_metrics = []\n"\ + f"{other_metrics_processor}\n"\ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n" + f"PatchRLStatistics('{trainer_file}', other_metrics)\n" # Patch optional args if trainer_file in RL_EXTRA_ARGS: diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 63fe24359..1e1306821 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -17,6 +17,7 @@ "RL_FUNCTIONS", "RL_PRE_ITEMS", "RL_CONFIG_CHANGES", + "RL_METRICS_CHANGES", ] import re @@ -24,10 +25,11 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS -RL_EXTRA_ARGS = defaultdict(list) -RL_FUNCTIONS = defaultdict(list) -RL_PRE_ITEMS = defaultdict(list) -RL_CONFIG_CHANGES = defaultdict(list) +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +RL_CONFIG_CHANGES = defaultdict(list) +RL_METRICS_CHANGES = dict() torch_compile_options = { "epilogue_fusion" : True, @@ -260,3 +262,19 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): return check_batch_size pass RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size) + + +# Add other reward function names +def grpo_trainer_metrics(RLTrainer_source, RLConfig_source): + if "reward_funcs" not in RLTrainer_source: return "" + + log_metrics = \ + "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\ + "for reward_func in _reward_funcs:\n"\ + " try:\n"\ + " reward_func_name = reward_func.__name__\n"\ + " other_metrics.append(f'rewards/{reward_func_name}')\n"\ + " except: pass\n" + return log_metrics +pass +RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics) From 46462f1de080607e3a8e88f69cb08912a9712145 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:12:49 -0800 Subject: [PATCH 353/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 1e1306821..95db25289 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -29,7 +29,7 @@ RL_FUNCTIONS = defaultdict(list) RL_PRE_ITEMS = defaultdict(list) RL_CONFIG_CHANGES = defaultdict(list) -RL_METRICS_CHANGES = dict() +RL_METRICS_CHANGES = defaultdict(list) torch_compile_options = { "epilogue_fusion" : True, From 12c497a64e22a0bafec1b4e331b5118401418a6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:17:26 -0800 Subject: [PATCH 354/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 95db25289..b2501c94f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -270,6 +270,7 @@ def grpo_trainer_metrics(RLTrainer_source, RLConfig_source): log_metrics = \ "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\ + "else: _reward_funcs = reward_funcs\n"\ "for reward_func in _reward_funcs:\n"\ " try:\n"\ " reward_func_name = reward_func.__name__\n"\ From c14faee9fd641eef4d5580103784ffe9a5c34a50 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 16:45:25 -0800 Subject: [PATCH 355/473] No compile --- unsloth/models/rl.py | 4 ++-- unsloth/models/rl_replacements.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7b363d8fc..d53c9606d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -112,12 +112,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): """ {__RLConfig_doc__} """ - sampling_params: Optional[Any] = field( + vllm_sampling_params: Optional[Any] = field( default = None, metadata = {{'help': 'vLLM SamplingParams'}}, ) def __init__({RLConfig_arguments}, - sampling_params = None, + vllm_sampling_params = None, **kwargs, ): {RLConfig_extra_args} diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b2501c94f..b9ba34726 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
# See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - return logits - # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + # return logits + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -199,7 +199,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 1fcad323e2a90c3fcdff09b579c25fc0f0ffe099 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 16:45:57 -0800 Subject: [PATCH 356/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d53c9606d..ac1b83667 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -535,8 +535,8 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params # Add spaces new_vllm_part = \ f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + f"if getattr(args, 'vllm_sampling_params', None) is None else "\ + f"getattr(args, 'vllm_sampling_params', None)\n{' '*8}else:\n" init = init.replace(vllm_part, new_vllm_part) pass pass From 80be827ba7f4b0c21174967fcaaa496e71251cd9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:36:18 -0800 Subject: [PATCH 357/473] Remove docs --- unsloth/models/rl.py | 19 ++++++++++++++++++- unsloth/models/rl_replacements.py | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ac1b83667..b13e6f9c7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -548,6 +548,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import changed = {"__init__" : (old_init, init,)} edit_functions = RL_FUNCTIONS.get(trainer_file, []) + remover = [] for function in functions: if not hasattr(RLTrainer, function): continue @@ -591,7 +592,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) # Skip if no changes done - if source == original_source: continue + if source == original_source: + remover.append(original_source) + continue # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] @@ -607,9 +610,23 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import old, new = changed[function] RLTrainer_source = RLTrainer_source.replace(old, new) pass + + # Remove non editted functions + for remove in remover: + RLTrainer_source = 
RLTrainer_source.replace(remove, "\n") + pass + RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) + + # Get rid of docs since we repeated it + RLTrainer_source = re.sub( + rf"class _Unsloth{RLTrainer_name}:.+?def __init__\(", + rf"class _Unsloth{RLTrainer_name}:\n def __init__(", + RLTrainer_source, + flags = re.MULTILINE | re.DOTALL, + ) return RLTrainer_source pass diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b9ba34726..46d44b92f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -40,7 +40,7 @@ } # Check untrained tokens -def sft_trainer_fix_untraiend_tokens(call_args, extra_args): +def sft_trainer_fix_untrained_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\ @@ -52,7 +52,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): return fix_tokenizer return "" pass -RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) +RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untrained_tokens) # Remove DPO columns which might randomnly be tokenized From 9254243f4d221fef9105856f59f78270a1d41b9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:48:52 -0800 Subject: [PATCH 358/473] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b13e6f9c7..8f60fa3ca 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -613,17 +613,17 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Remove non editted functions for remove in remover: - RLTrainer_source = RLTrainer_source.replace(remove, "\n") + RLTrainer_source = RLTrainer_source.replace(remove, "") pass - + RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) # Get rid of docs since we repeated it RLTrainer_source = re.sub( - rf"class _Unsloth{RLTrainer_name}:.+?def __init__\(", - rf"class _Unsloth{RLTrainer_name}:\n def __init__(", + rf"class _Unsloth{RLTrainer_name}(.*?:).+?def __init__\(", + rf"class _Unsloth{RLTrainer_name}\1\n def __init__(", RLTrainer_source, flags = re.MULTILINE | re.DOTALL, ) From 09cb804c784d4d6e7eeb28d1ce4c361fa136ca9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:57:47 -0800 Subject: [PATCH 359/473] Update rl.py --- unsloth/models/rl.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8f60fa3ca..51a5abb75 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -457,6 +457,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): selective_log_softmax_code = selective_log_softmax_code, ) + # Remove multiple doc strings + if RLTrainer_source.count(__RLTrainer_doc__) == 2: + RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) + pass + # Create new function created_module = create_new_function( f"Unsloth{RLTrainer_name}", @@ -619,14 +624,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) - - # Get rid of docs since we repeated it - RLTrainer_source = re.sub( - rf"class _Unsloth{RLTrainer_name}(.*?:).+?def 
__init__\(", - rf"class _Unsloth{RLTrainer_name}\1\n def __init__(", - RLTrainer_source, - flags = re.MULTILINE | re.DOTALL, - ) return RLTrainer_source pass From 86dabcfeef3dad65bdd4d1668c35275bc1250fbd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:00:08 -0800 Subject: [PATCH 360/473] Update rl.py --- unsloth/models/rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 51a5abb75..149846ca2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -422,7 +422,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ + if __RLTrainer_doc__ is None: __RLTrainer_doc__ = "" __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + if __RLConfig_doc__ is None: __RLConfig_doc__ = "" # Get all pre-modules if trainer_file in RL_PRE_ITEMS: @@ -458,7 +460,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): ) # Remove multiple doc strings - if RLTrainer_source.count(__RLTrainer_doc__) == 2: + if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2: RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) pass From ba1c93e485b0a193b42a8602272b8879de99c65b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:03:57 -0800 Subject: [PATCH 361/473] Update rl.py --- unsloth/models/rl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 149846ca2..2facd3ccb 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -464,6 +464,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) pass + # Remove multiple newlines + RLTrainer_source = re.sub(r"[\n]{3,}", "\n", RLTrainer_source) + # Create new function created_module = create_new_function( f"Unsloth{RLTrainer_name}", From 0d75afdffea695a179138c139994c8b0eacd12b7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:06:12 -0800 Subject: [PATCH 362/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 46d44b92f..ad6d0f2bb 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
# See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - # return logits - return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -199,7 +199,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 18036583ec599355a690f948b9fadb2b804f30bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:34:25 -0800 Subject: [PATCH 363/473] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2facd3ccb..df1f2f110 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -622,9 +622,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import pass # Remove non editted functions - for remove in remover: - RLTrainer_source = RLTrainer_source.replace(remove, "") - pass + # for remove in remover: + # RLTrainer_source = RLTrainer_source.replace(remove, "") + # pass RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 From a856085a8982628d22c7ce158e839a37fbc2dd11 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:35:07 -0800 Subject: [PATCH 364/473] Update rl.py --- unsloth/models/rl.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index df1f2f110..1b2f34854 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -558,7 +558,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import changed = {"__init__" : (old_init, init,)} edit_functions = RL_FUNCTIONS.get(trainer_file, []) - remover = [] for function in functions: if not hasattr(RLTrainer, function): continue @@ -602,9 +601,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) # Skip if no changes done - if source == original_source: - remover.append(original_source) - continue + if source == original_source: continue # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] @@ -621,11 +618,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import RLTrainer_source = RLTrainer_source.replace(old, new) pass - # Remove non editted functions - # for remove in remover: - # RLTrainer_source = RLTrainer_source.replace(remove, "") - # pass - RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) From 
eeac4f301b689c1a821e07e150279def4ad527ba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 20:04:35 -0800 Subject: [PATCH 365/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d0f2bb..a139a8533 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - return logits - # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + # return logits + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -198,8 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +# grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 6f1beb01a192e934a12ab752f0ab1c6693736d0b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 01:49:29 -0800 Subject: [PATCH 366/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a139a8533..ad6d0f2bb 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
# See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - # return logits - return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -198,8 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -# grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 222b1e7effef33f2d73ff63d95a32d078036f205 Mon Sep 17 00:00:00 2001 From: Gennadii Manzhos <105049664+everythingisc00l@users.noreply.github.com> Date: Sun, 16 Feb 2025 13:04:08 +0300 Subject: [PATCH 367/473] llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving breaks) (#1649) * edit save.py to fix gguf saving breaks. * add check for .exe or not exe file extension for linux and windows --- unsloth/save.py | 67 ++++++++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index d3ba1928c..0f75ecfd0 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -254,7 +254,7 @@ def unsloth_save_model( # First check for a token! if push_to_hub: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] except: raise RuntimeError( @@ -385,7 +385,7 @@ def unsloth_save_model( else: internal_model = model pass - + # Cannot be converted properly! if (save_method == "merged_4bit") or (save_method == "lora") or ( not hasattr(model, "model") or \ @@ -481,7 +481,7 @@ def unsloth_save_model( gb_found = re.match("([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE) mb_found = re.match("([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE) if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024 - elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 + elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 elif type(max_shard_size) is int: sharded_ram_usage = sharded_ram_usage pass @@ -612,7 +612,7 @@ def unsloth_save_model( # Edit save_pretrained_settings # [TODO] _create_repo has errors due to **kwargs getting accepted save_pretrained_settings["state_dict"] = state_dict - + # commit_description does not seem to work? 
what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \ if not push_to_hub else \ @@ -665,7 +665,7 @@ def unsloth_save_model( # Revert back padding side tokenizer.padding_side = old_padding_side - + print(" Done.") else: print() @@ -877,10 +877,15 @@ def install_llama_cpp_old(version = -10): pass # Check if successful - if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"): + if not ( + os.path.exists("llama.cpp/llama-quantize.exe") or + os.path.exists("llama.cpp/llama-quantize") or + os.path.exists("llama.cpp/quantize.exe") or + os.path.exists("llama.cpp/quantize") + ): raise RuntimeError( "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." ) pass pass @@ -957,7 +962,7 @@ def save_to_gguf( else: raise TypeError("Unsloth: quantization_method can only be a string or a list of strings") pass - + # Check if bfloat16 is supported if model_dtype == "bf16" and not torch.cuda.is_bf16_supported(): logger.warning( @@ -973,7 +978,7 @@ def save_to_gguf( pass # Check I quants - for quant_method in quantization_method: + for quant_method in quantization_method: if quant_method.startswith("iq2"): raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!") pass @@ -1026,9 +1031,9 @@ def save_to_gguf( pass # Determine whether the system already has llama.cpp installed and the scripts are executable - quantize_location = get_executable(["llama-quantize", "quantize"]) + quantize_location = get_executable(["llama-quantize", "quantize", "llama-quantize.exe", "quantize.exe"]) convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]) - + error = 0 if quantize_location is not None and convert_location is not None: print("Unsloth: llama.cpp found in the system. We shall skip installation.") @@ -1062,14 +1067,18 @@ def save_to_gguf( # and llama.cpp/main changed to llama.cpp/llama-cli # See https://github.com/ggerganov/llama.cpp/pull/7809 quantize_location = None - if os.path.exists("llama.cpp/quantize"): + if os.path.exists("llama.cpp/quantize.exe"): + quantize_location = "llama.cpp/quantize.exe" + elif os.path.exists("llama.cpp/quantize"): quantize_location = "llama.cpp/quantize" + elif os.path.exists("llama.cpp/llama-quantize.exe"): + quantize_location = "llama.cpp/llama-quantize.exe" elif os.path.exists("llama.cpp/llama-quantize"): quantize_location = "llama.cpp/llama-quantize" else: raise RuntimeError( - "Unsloth: The file 'llama.cpp/llama-quantize' or 'llama.cpp/quantize' does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" + "Unsloth: The file ('llama.cpp/llama-quantize' or 'llama.cpp/llama-quantize.exe' if you are on Windows WSL) or 'llama.cpp/quantize' does not exist.\n"\ + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." 
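The hunk just above widens the search for the quantize binary so it works on Windows/WSL builds that produce `.exe` files as well as older checkouts that still ship `quantize`. Reduced to a standalone helper, the check amounts to roughly the following sketch (the helper name is made up; save.py itself keeps the explicit if/elif chain shown above):

    import os

    def find_quantize_binary(base="llama.cpp"):
        # Accept either the newer llama-quantize name or the legacy quantize name,
        # with or without a Windows/WSL .exe suffix (order here is illustrative).
        for name in ("llama-quantize", "llama-quantize.exe", "quantize", "quantize.exe"):
            path = os.path.join(base, name)
            if os.path.exists(path):
                return path
        return None  # caller raises the descriptive RuntimeError seen in the patch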
) pass @@ -1150,7 +1159,7 @@ def save_to_gguf( # Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model final_location = str((Path(model_directory) / f"unsloth.{first_conversion.upper()}.gguf").absolute()) - + print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\ f"The output location will be {final_location}\n"\ "This might take 3 minutes...") @@ -1217,7 +1226,7 @@ def save_to_gguf( command = f"./{quantize_location} {full_precision_location} "\ f"{final_location} {quant_method} {n_cpus}" - + try_execute([command,], force_complete = True) # Check if quantization succeeded! @@ -1378,7 +1387,7 @@ def _determine_username(save_directory, old_username, token): save_directory = save_directory.lstrip("./") if "/" not in save_directory: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] if type(old_username) is str and username != old_username: username = old_username @@ -1412,7 +1421,7 @@ def create_huggingface_repo( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1453,7 +1462,7 @@ def upload_to_huggingface( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1527,7 +1536,7 @@ def fix_tokenizer_bos_token(tokenizer): # Check if BOS added already, then warn fix_bos_token = False chat_template = getattr(tokenizer, "chat_template", None) - + if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)): if chat_template is not None and \ ( @@ -1546,7 +1555,7 @@ def fix_tokenizer_bos_token(tokenizer): new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template) # Remove {{bos_token + new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template) - + tokenizer.chat_template = new_chat_template pass @@ -1580,7 +1589,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): modelfile = modelfile\ .replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\ .replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}") - + if "__EOS_TOKEN__" in modelfile: modelfile = modelfile.format( __FILE_LOCATION__ = gguf_location, @@ -1591,7 +1600,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): __FILE_LOCATION__ = gguf_location, ) pass - + modelfile = modelfile\ .replace("⚫@✅#🦥", "{")\ .replace("⚡@🦥#⛵", "}")\ @@ -1733,7 +1742,7 @@ def unsloth_save_pretrained_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1911,7 +1920,7 @@ def unsloth_push_to_hub_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1928,7 +1937,7 @@ def unsloth_push_to_hub_gguf( # If not needing full precision, skip the first if not want_full_precision: all_file_locations = all_file_locations[1:] - + for file_location in all_file_locations: print("Unsloth: Uploading GGUF to Huggingface Hub...") username = upload_to_huggingface( @@ -2044,8 +2053,8 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub( def unsloth_convert_lora_to_ggml_and_save_locally( self, - save_directory: str, # Added parameter 
for the folder name - tokenizer, + save_directory: str, # Added parameter for the folder name + tokenizer, temporary_location: str = "_unsloth_temporary_saved_buffers", maximum_memory_usage: float = 0.85, ): @@ -2162,7 +2171,7 @@ def unsloth_generic_save_pretrained_merged( tags : List[str] = None, temporary_location : str = "_unsloth_temporary_saved_buffers", maximum_memory_usage : float = 0.75, -): +): """ Same as .push_to_hub(...) except 4bit weights are auto converted to float16 with as few overhead as possible. From 103cff459a11fc3ecd293e342b1aecaa00bb35aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 16:14:21 -0800 Subject: [PATCH 368/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d0f2bb..ad6d7822a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -229,6 +229,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] + print(input_ids.shape, ref_per_token_logps.shape, per_token_logps.shape, completion_mask.shape, advantages.shape) loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) From 89a1d035ae5692c2edebf473b63bb36548c5866d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 17:13:24 -0800 Subject: [PATCH 369/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d7822a..f2ac7f80d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -201,6 +202,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +global INPUTS + # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function @@ -229,10 +232,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - print(input_ids.shape, ref_per_token_logps.shape, per_token_logps.shape, completion_mask.shape, advantages.shape) loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + global INPUTS + INPUTS = ( + ref_per_token_logps, per_token_logps, input_ids, 
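Aside on the `_autocast_dtype` lines that keep reappearing in these `_get_per_token_logps` hunks: they read Accelerate's mixed-precision setting once, cache it on the trainer, and reuse it so the completion logits are computed in the same fp16/bf16 dtype as the rest of training. Stripped of the caching, the choice is just the following (sketch only; the usage lines are commented and illustrative):

    import os
    import torch

    def resolve_autocast_dtype():
        # Mirror ACCELERATE_MIXED_PRECISION: fp16 by default, bf16 only when requested.
        mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16")
        return torch.float16 if mixed_precision == "fp16" else torch.bfloat16

    # Typical use inside a _get_per_token_logps-style function:
    # with torch.amp.autocast(device_type="cuda", dtype=resolve_autocast_dtype()):
    #     logits = model(input_ids=input_ids, attention_mask=attention_mask).logits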
completion_mask, self.beta, advantages, + loss, completion_length, mean_kl, + ) + raise # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From c46b544c8c370e650bbcfb163adad8577f765e17 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:09:03 -0800 Subject: [PATCH 370/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f2ac7f80d..b91c80871 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,6 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): From ed84307d46fea7090ea506b91301de9eff1b05da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:27:04 -0800 Subject: [PATCH 371/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b91c80871..8e930261f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -202,6 +202,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) global INPUTS +INPUTS = None # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From 93d3f162f0a6f51db8e2302dc9a255dc33825605 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:34:45 -0800 Subject: [PATCH 372/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8e930261f..92b12647c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,9 +201,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) -global INPUTS -INPUTS = None - # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function @@ -235,8 +232,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) - global INPUTS - INPUTS = ( + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS + RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, ) From 429ba6d57de05cf3c0b8bf73eb76ceab1823972f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 19:47:39 -0800 Subject: [PATCH 373/473] Update rl_replacements.py --- 
unsloth/models/rl_replacements.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 92b12647c..b058d0d27 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,11 +233,14 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) from unsloth_zoo.rl_replacements import RL_REPLACEMENTS + if "count" in RL_REPLACEMENTS: + RL_REPLACEMENTS["count"] += 1 + if RL_REPLACEMENTS["count"] == 5: raise + else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, ) - raise # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 1e42bad4adffd6694407a3a43bd43813371a2589 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:20:01 -0800 Subject: [PATCH 374/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b058d0d27..bb41cff75 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,7 +235,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 5: raise + if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, From 38a1885bf619e22d4ce2c8fb07caa01030975d29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:55:11 -0800 Subject: [PATCH 375/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index bb41cff75..034ce8678 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,11 +235,11 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 10: raise + if RL_REPLACEMENTS["count"] == 20: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, + loss, completion_length, mean_kl, completion_ids, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() From f0ee4f5c91e107b28b866dded3c53f736b625d81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:55:26 -0800 Subject: [PATCH 376/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 034ce8678..53ec6e6cd 100644 --- a/unsloth/models/rl_replacements.py +++ 
b/unsloth/models/rl_replacements.py @@ -235,7 +235,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 20: raise + if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, From b68dce6b766f72be33560fea6ab00a8b63a7427d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 21:06:47 -0800 Subject: [PATCH 377/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 53ec6e6cd..77d7e6a53 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -216,7 +216,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + _input_ids = input_ids + _logits_to_keep = logits_to_keep per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -238,8 +239,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, completion_ids, + ref_per_token_logps, per_token_logps, _input_ids, completion_mask, self.beta, advantages, + loss, completion_length, mean_kl, completion_ids, _logits_to_keep, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() From 0827067906d73cfa65ad97501f40a79e4d2dbbc5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 21:22:35 -0800 Subject: [PATCH 378/473] Update llama.py --- unsloth/models/llama.py | 62 ++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1eae97ff1..9403b50e4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1030,6 +1030,7 @@ def _CausalLM_fast_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_logits_to_keep: Optional[int] = 0, + logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -1053,16 +1054,16 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = 
use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, ) pass hidden_states = outputs[0] @@ -1072,6 +1073,7 @@ def _CausalLM_fast_forward( logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) dtype = lm_head.dtype + num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) @@ -1180,28 +1182,30 @@ def _CausalLM_fast_forward( @torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, - input_ids=None, - causal_mask=None, - attention_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - task_ids=None, - num_logits_to_keep=0, + input_ids = None, + causal_mask = None, + attention_mask = None, + inputs_embeds = None, + labels = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + task_ids = None, + num_logits_to_keep = 0, + logits_to_keep = 0, **kwargs, ): return self.base_model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, + num_logits_to_keep = num_logits_to_keep, + logits_to_keep = logits_to_keep, **kwargs, ) pass From 204cd7a38ad946c7e0c7767f6d9807148361bc81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 23:49:20 -0800 Subject: [PATCH 379/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 77d7e6a53..99dba9b9a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,7 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps, _input_ids, completion_mask, self.beta, advantages, + ref_per_token_logps, per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, completion_ids, _logits_to_keep, ) # Log the metrics From e14107523c95b0ee3515071d81466ca966d04f9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:05:32 -0800 Subject: [PATCH 380/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 99dba9b9a..0f1c81bb8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,15 +233,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + RL_REPLACEMENTS["data"] = ( + ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, 
completion_mask, self.beta, advantages, + loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, + ) from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 - RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, completion_ids, _logits_to_keep, - ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From a07a9e3c1d0bd3019716b31dc97df1b71532a552 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:43:11 -0800 Subject: [PATCH 381/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0f1c81bb8..eb41507b1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,11 +233,11 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_REPLACEMENTS["data"] = ( ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, ) - from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 if RL_REPLACEMENTS["count"] == 10: raise From cf2720d1812f1727290e9c4bbe09a68ef4441f9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:49:35 -0800 Subject: [PATCH 382/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9403b50e4..378431ec5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -700,6 +700,7 @@ def LlamaModel_fast_forward( elif inputs_requires_grad: inputs_embeds.requires_grad_(False) pass + attention_mask = attention_mask[:,:self.max_seq_length] # Must resize! inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2) if inputs_requires_grad: inputs_embeds.requires_grad_(True) pass From 5c6f5866beb723eb35bf1a406db9d14801e6cc77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:16:41 -0800 Subject: [PATCH 383/473] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 378431ec5..f34968c3a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1699,9 +1699,9 @@ def from_pretrained( elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 - elif dtype == torch.float16 and SUPPORTS_BFLOAT16: - logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") - dtype = torch.bfloat16 + # elif dtype == torch.float16 and SUPPORTS_BFLOAT16: + # logger.warning_once("Device supports bfloat16 but you selected float16. 
Will change to bfloat16.") + # dtype = torch.bfloat16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) From 2e0762385723b542f33c855f170f49d2862a7d79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:44:43 -0800 Subject: [PATCH 384/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index eb41507b1..86cc2fb14 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -198,7 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision @@ -213,6 +214,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + bsz, qlen = input_ids.shape # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -233,6 +235,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, + ) + print("loss", loss, accumulated_loss) + print("completion_length", completion_length, accumulated_completion_length) + print("mean_kl", mean_kl, accumulated_mean_kl) + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_REPLACEMENTS["data"] = ( ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, From 8025cfeefbeb42d74e4d1195269e447a4d7067d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:45:07 -0800 Subject: [PATCH 385/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 86cc2fb14..3d97b90df 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,7 +233,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, From 
ba484956752e0bc432b8d1d8b65444f48abff43b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:53:49 -0800 Subject: [PATCH 386/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3d97b90df..17215bafb 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,6 +201,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_accumulated_loss"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From f0078de7b982c71e89e612d42663550258015920 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:58:17 -0800 Subject: [PATCH 387/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1b2f34854..746889785 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -429,6 +429,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) + print(RL_pre) else: RL_pre = "" pass From 15e014043a5d2fc1d168d9e98d027f1748e8546e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:00:04 -0800 Subject: [PATCH 388/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 746889785..646676558 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -429,7 +429,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) - print(RL_pre) + print(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From 5f5cca406fed09cf7d90c1ef866a515baa24f1a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:04:14 -0800 Subject: [PATCH 389/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 17215bafb..d8a1c6371 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,7 +201,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) -RL_PRE_ITEMS["grpo_accumulated_loss"].append(inspect.getsource(grpo_accumulated_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From d80be70ac4d703a57e1fbd6c47842276f2a86aaa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:04:26 -0800 Subject: [PATCH 390/473] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 646676558..1b2f34854 100644 --- a/unsloth/models/rl.py 
+++ b/unsloth/models/rl.py @@ -429,7 +429,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) - print(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From 47a85eba5a7bf64804da1511563d682d889bbff0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:08:15 -0800 Subject: [PATCH 391/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1b2f34854..f36598b0a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -94,6 +94,7 @@ def generate_with_clone(*args, **kwargs): from dataclasses import dataclass, field from packaging.version import Version import torch +import numpy as np from contextlib import nullcontext from torch.nn import functional as F torch_compile_options = {{ From f09478de3672e7281d3de360320201d2f1d1885d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:21:20 -0800 Subject: [PATCH 392/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index d8a1c6371..ee57055a0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -237,7 +237,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, ) print("loss", loss, accumulated_loss) print("completion_length", completion_length, accumulated_completion_length) From 97637c5b3d29ee999f004debd2fe05db490f034b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:38:18 -0800 Subject: [PATCH 393/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ee57055a0..b1a2ba8f7 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -221,7 +222,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + # per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] @@ -233,25 +234,13 @@ def compute_loss(self, model, inputs, 
return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) + # loss, completion_length, mean_kl = grpo_compute_loss( + # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + # ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, - ) - print("loss", loss, accumulated_loss) - print("completion_length", completion_length, accumulated_completion_length) - print("mean_kl", mean_kl, accumulated_mean_kl) - - from unsloth_zoo.rl_replacements import RL_REPLACEMENTS - RL_REPLACEMENTS["data"] = ( - ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, - loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, ) - if "count" in RL_REPLACEMENTS: - RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 10: raise - else: RL_REPLACEMENTS["count"] = 1 + loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 58bd27f332e5ce3d0d038b44ed003ae8184fae68 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:03:44 -0800 Subject: [PATCH 394/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b1a2ba8f7..9a1cf4b4c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None + # return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - # loss, completion_length, mean_kl = grpo_compute_loss( - # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - # ) - accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) - loss, completion_length, mean_kl = accumulated_loss, 
accumulated_completion_length, accumulated_mean_kl + # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # ) + # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 7c0c7493cb301dada287d3d9955b190091cab5bd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:08:32 -0800 Subject: [PATCH 395/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9a1cf4b4c..b1a2ba8f7 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # return None + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) - # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # loss, completion_length, mean_kl = grpo_compute_loss( + # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, # ) - # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + ) + loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 97b55c139f38daff37c3e789918dea5b2c04f7fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:10:26 -0800 Subject: [PATCH 396/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b1a2ba8f7..21f271258 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, 
model, input_ids, attention_mask, logits_to_keep): - return None + # return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -222,7 +222,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - # per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - # loss, completion_length, mean_kl = grpo_compute_loss( - # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - # ) - accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) - loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # ) + # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 24c7a2f7c49cbca7005a46be1577f6d1bd7dedf5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:17:58 -0800 Subject: [PATCH 397/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 21f271258..405f79094 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # return None + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, 
-logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) - # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, - # ) - # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + if per_token_logps is not None: + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + ) + else: + loss, completion_length, mean_kl = grpo_accumulated_loss( + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + ) + # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 06b2cd3e57c0befd273ddc4e256c1bfeaa04ba1f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:17:11 -0800 Subject: [PATCH 398/473] unsloth_num_chunks --- unsloth/models/rl.py | 4 ++++ unsloth/models/rl_replacements.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f36598b0a..fa617d5d4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -117,6 +117,10 @@ class Unsloth{RLConfig_name}({RLConfig_name}): default = None, metadata = {{'help': 'vLLM SamplingParams'}}, ) + unsloth_num_chunks : Optional[int] = field( + default = 1, + metadata = {{'help': 'Chunk size to reduce memory usage'}}, + ) def __init__({RLConfig_arguments}, vllm_sampling_params = None, **kwargs, diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 405f79094..decaf3209 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None + if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -240,7 +240,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ) else: loss, completion_length, mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + self, _input_ids, logits_to_keep, completion_mask, advantages, + n_chunks = self.args.unsloth_num_chunks, ) # Log the metrics From cbb16e363b3ac6bd730f34abeef8e1a714de7d2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:24:57 -0800 Subject: [PATCH 399/473] Update rl.py --- unsloth/models/rl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fa617d5d4..231dbe776 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -122,7 +122,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'Chunk size to reduce memory usage'}}, ) def __init__({RLConfig_arguments}, - vllm_sampling_params = None, + vllm_sampling_params = vllm_sampling_params, + unsloth_num_chunks = unsloth_num_chunks, 
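For orientation: the `unsloth_num_chunks` option introduced in the patches above is an ordinary dataclass field on the generated config (default 1, i.e. no chunking), and `compute_loss` forwards it as `n_chunks = self.args.unsloth_num_chunks` to `grpo_accumulated_loss` so the loss over the completion can be evaluated in pieces to cap peak memory. Outside the code-generation template it would look roughly like this (the class name is a stand-in, not the generated UnslothGRPOConfig):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class ExampleGRPOConfig:
        # Chunk count used by the accumulated GRPO loss; 1 disables chunking.
        unsloth_num_chunks: Optional[int] = field(
            default=1,
            metadata={"help": "Chunk size to reduce memory usage"},
        )

    # config = ExampleGRPOConfig(unsloth_num_chunks=4)
    # loss, completion_length, mean_kl = grpo_accumulated_loss(
    #     trainer, input_ids, logits_to_keep, completion_mask, advantages,
    #     n_chunks=config.unsloth_num_chunks,
    # )  # call shape as used in the patches; the implementation lives in unsloth_zoo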
**kwargs, ): {RLConfig_extra_args} From d16299b1549ffe59018253a6ad1aac89f45444dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:30:13 -0800 Subject: [PATCH 400/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index decaf3209..3b23e8bac 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,6 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: + print(self.args.unsloth_num_chunks, end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 0c1a808e3a5828c615921fe7d3c8c10d7de6324c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:30:20 -0800 Subject: [PATCH 401/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3b23e8bac..443c8b267 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,7 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: - print(self.args.unsloth_num_chunks, end = ",") + print(int(self.args.unsloth_num_chunks), end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 67968012470a1e484a6f2cc69d3e5376b3ba24c6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 23:47:52 -0800 Subject: [PATCH 402/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 443c8b267..bcfe4d777 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + print(self.args.unsloth_num_chunks) if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 From bd046ca2265c95dbcd94fe9574cb606f85748956 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 23:57:57 -0800 Subject: [PATCH 403/473] Update rl.py --- unsloth/models/rl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 231dbe776..7a90b8115 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -442,6 +442,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) + # Trainer kwargs + comma = "" if RLTrainer_call_args.endswith(",") else "," + unsloth_extra_args = comma + \ + "vllm_sampling_params = vllm_sampling_params,\n"\ + "unsloth_num_chunks = unsloth_num_chunks, **kwargs" + # Get 
final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -449,7 +455,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], + RLTrainer_kwargs = unsloth_extra_args, RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, From ac2e814c2509a8751d920bfd74941812d3e6add1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:01:09 -0800 Subject: [PATCH 404/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7a90b8115..3b7b88b6c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -455,14 +455,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = unsloth_extra_args, + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args .endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, - RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], + RLConfig_kwargs = unsloth_extra_args, RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, From a88712f94ac82708a2ea33f716ed232f56908e27 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:05:40 -0800 Subject: [PATCH 405/473] Update rl.py --- unsloth/models/rl.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3b7b88b6c..231dbe776 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -442,12 +442,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) - # Trainer kwargs - comma = "" if RLTrainer_call_args.endswith(",") else "," - unsloth_extra_args = comma + \ - "vllm_sampling_params = vllm_sampling_params,\n"\ - "unsloth_num_chunks = unsloth_num_chunks, **kwargs" - # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -455,14 +449,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args .endswith(",") else 0:], + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, - RLConfig_kwargs = unsloth_extra_args, + RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, From 0daa328df3964cd0a16d23b6ffca7dcec4eb7581 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:09:11 -0800 Subject: [PATCH 406/473] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py 
b/unsloth/models/rl.py index 231dbe776..da73ec49f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -122,8 +122,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'Chunk size to reduce memory usage'}}, ) def __init__({RLConfig_arguments}, - vllm_sampling_params = vllm_sampling_params, - unsloth_num_chunks = unsloth_num_chunks, + vllm_sampling_params = None, + unsloth_num_chunks = 1, **kwargs, ): {RLConfig_extra_args} From 1afe3f2bf6ba968a9a738c2aae1ffe4a486be9d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:13:26 -0800 Subject: [PATCH 407/473] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index da73ec49f..29773d0a8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -128,6 +128,8 @@ def __init__({RLConfig_arguments}, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) + self.vllm_sampling_params = vllm_sampling_params + self.unsloth_num_chunks = unsloth_num_chunks pass {RLTrainer_extras} From 6732822a83782f19fe96695c980664adb012a37f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:17:07 -0800 Subject: [PATCH 408/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index bcfe4d777..decaf3209 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,6 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - print(self.args.unsloth_num_chunks) if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 @@ -240,7 +239,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: - print(int(self.args.unsloth_num_chunks), end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 5efe9f356c4a674b3038c7c5ae004b7813d4e3b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 01:57:18 -0800 Subject: [PATCH 409/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index decaf3209..5fa4ec5a4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -234,9 +234,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - if per_token_logps is not None: + if False:#per_token_logps is not None: loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) else: loss, completion_length, 
mean_kl = grpo_accumulated_loss( From 15442d1036e9574e9398cc85ec8b576d6196ebf1 Mon Sep 17 00:00:00 2001 From: Seth Weidman Date: Wed, 19 Feb 2025 02:12:07 -0800 Subject: [PATCH 410/473] Update rl_replacements.py (#1754) Fix typo in comment: know -> now. This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well. --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5fa4ec5a4..c8caa1b58 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -268,7 +268,7 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ - " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size From 91ab43dbd40788cbea2098c76991fef21bb05c1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 02:23:00 -0800 Subject: [PATCH 411/473] Optional logits --- unsloth/models/llama.py | 23 ++++++++++++++++++----- unsloth/models/rl.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f34968c3a..27651be97 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1076,6 +1076,19 @@ def _CausalLM_fast_forward( dtype = lm_head.dtype num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) + # Output last hidden states without logits if asked + if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + if num_logits_to_keep != 0: + hidden_states = hidden_states[:, -num_logits_to_keep:, :] + return CausalLMOutputWithPast( + loss = None, + logits = hidden_states, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, + ) + pass + if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) logits = logits.unsqueeze(0).unsqueeze(0) @@ -1169,11 +1182,11 @@ def _CausalLM_fast_forward( return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = loss, + logits = logits, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, ) pass return _CausalLM_fast_forward diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 29773d0a8..6947be81a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From a6a5f609955ca3ef8bb98ecdb98f0d7815bf7558 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 03:41:51 -0800 Subject: [PATCH 412/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py 
b/unsloth/models/rl.py index 6947be81a..fc92e1b32 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + "hasattr(trainer.args, 'use_vllm') and (getattr(trainer.args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From 83ce085c881796a04d1c5bf17ced356b4f230ca9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 12:47:15 -0800 Subject: [PATCH 413/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fc92e1b32..85b66e3f8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "hasattr(trainer.args, 'use_vllm') and (getattr(trainer.args, 'use_vllm', False) == False): "\ + "hasattr(self.args, 'use_vllm') and (getattr(self.args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From 8ece11ffbaa74a86a5be07096189d1acbdf8825e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 12:51:22 -0800 Subject: [PATCH 414/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 85b66e3f8..48f04412f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "hasattr(self.args, 'use_vllm') and (getattr(self.args, 'use_vllm', False) == False): "\ + "hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From bc6bfae66331e341ab85b2a514e93ee1f0229131 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:37:12 -0800 Subject: [PATCH 415/473] Update rl.py --- unsloth/models/rl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 48f04412f..9980d3278 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer @@ -547,6 +547,13 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) if len(sampling_params) == 1: sampling_params = sampling_params[0] + + # Fix guided_decoding + sampling_params = sampling_params.replace( + "guided_decoding=guided_decoding,", + 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ + 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', + ) # Replace with our vLLM engine sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ From 95fb6a49f2aca9ace6aab6fa9a34d3ed8f4817d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:38:44 -0800 Subject: 
[PATCH 416/473] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9980d3278..e977d2f91 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -551,6 +551,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Fix guided_decoding sampling_params = sampling_params.replace( "guided_decoding=guided_decoding,", + 'guided_decoding='\ 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', ) From ba01cf500d41cb369ba31d894711480094d8b485 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:40:37 -0800 Subject: [PATCH 417/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e977d2f91..24f503dc6 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -553,7 +553,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import "guided_decoding=guided_decoding,", 'guided_decoding='\ 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ - 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', + 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None,', ) # Replace with our vLLM engine sampling_params = \ From eb48b98bcf08ac10ef6b15cdddba2106792d3b42 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 17:58:25 -0800 Subject: [PATCH 418/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24f503dc6..1aacade93 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From 3c750a1608d8f0dfbd424616a0ce76c4b056fb19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:41:17 -0800 Subject: [PATCH 419/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1aacade93..24f503dc6 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer From 515cf5a764d61cbfb5beea7f2041d3b8c4229f8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 22:03:47 -0800 Subject: [PATCH 420/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c8caa1b58..5d6201dd2 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -200,8 +200,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) 
+RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision From 2cf4349740d98d2519184fdf0663a222c801fc74 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:06:18 -0800 Subject: [PATCH 421/473] Update rl.py --- unsloth/models/rl.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24f503dc6..c8602d31b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -36,12 +36,19 @@ torch_compile_options = { "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, # Disable Triton mm kernels "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, } + +def vLLMSamplingParams(**kwargs): + sampling_params = SamplingParams(**kwargs) + sampling_params._set_kwargs = kwargs + return sampling_params +pass + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -99,7 +106,7 @@ def generate_with_clone(*args, **kwargs): from torch.nn import functional as F torch_compile_options = {{ "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, @@ -128,6 +135,7 @@ def __init__({RLConfig_arguments}, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) + assert(hasattr(vllm_sampling_params, '_set_kwargs')) self.vllm_sampling_params = vllm_sampling_params self.unsloth_num_chunks = unsloth_num_chunks pass @@ -441,6 +449,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RL_pre = "" pass + # Check if SamplingParams is in there + if "SamplingParams" in RLTrainer_source: + RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams) + pass + # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) @@ -559,10 +572,17 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces + + # Add extra arguments to SamplingParams + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" + sampling_params = sampling_params.replace(")", "," + extra + "," + ")") + # Strip multiple commas + sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) + new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'vllm_sampling_params', None) is None else "\ - f"getattr(args, 'vllm_sampling_params', None)\n{' '*8}else:\n" + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\ + f"\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) pass pass From ae8bf68e4dd3fafe4378c5b24b4220737f5292dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:15:13 -0800 Subject: [PATCH 422/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c8602d31b..f754fa953 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -450,7 +450,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Check if SamplingParams is in there - if "SamplingParams" in RLTrainer_source: + if "SamplingParams" in old_RLTrainer_source: RL_pre = RL_pre + "\n" + 
inspect.getsource(vLLMSamplingParams) pass From e07f4bc303010c27587da253a49a4d8d0b1f0280 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:23:39 -0800 Subject: [PATCH 423/473] Update rl.py --- unsloth/models/rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f754fa953..38f9ab5a0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -575,7 +575,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Add extra arguments to SamplingParams extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" - sampling_params = sampling_params.replace(")", "," + extra + "," + ")") + # Backwards replace + to_replace = "," + extra + "," + ")" + sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) # Strip multiple commas sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) From 3fccf5d6b0355a911e25ae7627dd5cb66ce26a0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:27:16 -0800 Subject: [PATCH 424/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 38f9ab5a0..3ab45cdf7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -574,7 +574,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params # Add spaces # Add extra arguments to SamplingParams - extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})" # Backwards replace to_replace = "," + extra + "," + ")" sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) From 798ad9588118899e73178810ff5e90d2afeb5642 Mon Sep 17 00:00:00 2001 From: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> Date: Thu, 20 Feb 2025 08:32:25 +0100 Subject: [PATCH 425/473] fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han --- unsloth/models/loader.py | 10 +++++++--- unsloth/save.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 39b367e27..186545cf0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -24,10 +24,14 @@ from .loader_utils import get_model_name import os, contextlib, sys try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from huggingface_hub import HfFileSystem import importlib.util diff --git a/unsloth/save.py b/unsloth/save.py index 0f75ecfd0..eaddfa05c 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -31,10 +31,14 @@ from .tokenizer_utils import fix_sentencepiece_gguf from huggingface_hub import HfApi try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token 
import get_token + pass pass from pathlib import Path From 2957d89d6786d100c92c608f4d73c5146f8abc06 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:37:52 -0800 Subject: [PATCH 426/473] SamplingParams --- unsloth/models/__init__.py | 2 +- unsloth/models/rl.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index b15e04ab7..29ad78dae 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchFastRL +from .rl import PatchFastRL, vLLMSamplingParams diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3ab45cdf7..572caf594 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -44,6 +44,7 @@ def vLLMSamplingParams(**kwargs): + from vllm import SamplingParams sampling_params = SamplingParams(**kwargs) sampling_params._set_kwargs = kwargs return sampling_params From 19d57bcae6cece5ab4d31836c762f60e2dfa9256 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Thu, 20 Feb 2025 11:38:48 +0400 Subject: [PATCH 427/473] Convert mask to float (#1762) --- unsloth/models/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 27651be97..909dfc339 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -775,9 +775,12 @@ def LlamaModel_fast_forward( self.SWA_mask = True self.GA_mask = False elif attention_mask is not None: - # Fixes https://github.com/unslothai/unsloth/issues/853 # Unsloth needs a 2D mask, not a [2, 1, n, n] mask! + + # https://github.com/pytorch/pytorch/issues/103749 + # Need to convert to float and not using bool + attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), From 07aea401fab4b916b8ea41f7c52c218c619bf534 Mon Sep 17 00:00:00 2001 From: Ben <6579034+versipellis@users.noreply.github.com> Date: Wed, 19 Feb 2025 23:40:07 -0800 Subject: [PATCH 428/473] [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs --- README.md | 7 +++++-- pyproject.toml | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 45312a43d..4bdd7e289 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git ### Windows Installation To run Unsloth directly on Windows: -- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows +- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12) - In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue: ```python trainer = SFTTrainer( @@ -202,12 +202,15 @@ trainer = SFTTrainer( ) ``` +### Advanced/Troubleshooting + For **advanced installation instructions** or if you see weird errors during installations: 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton` 2. 
Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers. 3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs. -4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` +4. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful. +5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` ## 📜 [Documentation](https://docs.unsloth.ai) - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more! diff --git a/pyproject.toml b/pyproject.toml index 59a7c4473..07085adcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,6 +196,10 @@ cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", From f3d9efb40ca611acd2354341b78a272f9491f530 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:43:52 -0800 Subject: [PATCH 429/473] vLLMSamplingParams --- unsloth/__init__.py | 1 + unsloth/models/rl.py | 1 + 2 files changed, 2 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f0600f332..ee3024bc9 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -210,6 +210,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * +from .rl import vLLMSamplingParams from .save import * from .chat_templates import * from .tokenizer_utils import * diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 572caf594..0207f1c9b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -14,6 +14,7 @@ __all__ = [ "PatchFastRL", + "vLLMSamplingParams", ] import torch From 6d5caca27196a1d13d00491c6c248098ce6bfe29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:45:07 -0800 Subject: [PATCH 430/473] Update __init__.py --- unsloth/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index ee3024bc9..f0600f332 100644 --- a/unsloth/__init__.py +++ 
b/unsloth/__init__.py @@ -210,7 +210,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * -from .rl import vLLMSamplingParams from .save import * from .chat_templates import * from .tokenizer_utils import * From 3a5610e53fdde2406087f388f65e2139f77fc11c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:51:06 -0800 Subject: [PATCH 431/473] default num_chunks == -1 --- unsloth/models/rl.py | 6 +++--- unsloth/models/rl_replacements.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0207f1c9b..f6b3fdbf3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -127,12 +127,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'vLLM SamplingParams'}}, ) unsloth_num_chunks : Optional[int] = field( - default = 1, - metadata = {{'help': 'Chunk size to reduce memory usage'}}, + default = -1, + metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}}, ) def __init__({RLConfig_arguments}, vllm_sampling_params = None, - unsloth_num_chunks = 1, + unsloth_num_chunks = -1, **kwargs, ): {RLConfig_extra_args} diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5d6201dd2..23b31172f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if self.args.unsloth_num_chunks != 1: return None + return None # Unsloth efficient GRPO if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): From 0362bd22faf0d4206b5a2e977a181ed9168c7de7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 04:22:17 -0800 Subject: [PATCH 432/473] Versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- unsloth/models/mapper.py | 5 ----- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07085adcc..96aa0696f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -348,7 +348,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f0600f332..a3b3e68b2 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.2.4"): + if Version(unsloth_zoo_version) < Version("2025.2.6"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0c51c174f..52b371091 100644 --- 
a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.12" +__version__ = "2025.2.13" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 2e85d3014..da7f449bb 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -601,11 +601,6 @@ "Qwen/Qwen2.5-VL-72B-Instruct", "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", ), - "unsloth/DeepHermes-3-Llama-3-8B-Preview-unsloth-bnb-4bit" : ( - "unsloth/DeepHermes-3-Llama-3-8B-Preview", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - "unsloth/DeepHermes-3-Llama-3-8B-Preview-bnb-4bit", - ), "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : ( "unsloth/DeepHermes-3-Llama-3-8B-Preview", "agentica-org/DeepScaleR-1.5B-Preview", From b5eda24d81808f36562daae7ae44b5a84f43b0b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:01:14 -0800 Subject: [PATCH 433/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 909dfc339..579376cdd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print(attention_mask) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 7de002246fe0c60769b2874e750ec7964bf0bc1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:25:31 -0800 Subject: [PATCH 434/473] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 579376cdd..4d8ec1367 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,12 +449,12 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2)#.contiguous() + A = A.transpose(1, 2).contiguous() else: if n_groups != 1: K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) From d4d7694dd950053f9422d7e38963530a59efa15c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:36:23 -0800 Subject: [PATCH 435/473] Update llama.py --- unsloth/models/llama.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4d8ec1367..f19609fa4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -247,7 +247,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: + if True: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -266,10 +266,7 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - if SDPA_HAS_GQA: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) - else: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) From 0bbfbe802ec32930b5262d8b087ad5cc15dea493 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:40:45 -0800 Subject: [PATCH 436/473] Update llama.py --- unsloth/models/llama.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f19609fa4..44765fdd9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -247,7 +247,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if True: + if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -266,7 +266,10 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + if SDPA_HAS_GQA: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) @@ -448,10 +451,9 @@ def LlamaAttention_fast_forward( if SDPA_HAS_GQA: # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! 
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + A = A.transpose(1, 2)#.contiguous() else: if n_groups != 1: K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) @@ -723,8 +725,8 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) - if attention_mask is not None: - attention_mask = attention_mask.to(torch.bool) + # if attention_mask is not None: + # attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From ae6e2bd67127f11e602f7ecb832489e58a31de45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:46:14 -0800 Subject: [PATCH 437/473] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 44765fdd9..3e0717a87 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -725,6 +725,7 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) + # Must NOT convert to bool - weirdly this causes stuff to error out! # if attention_mask is not None: # attention_mask = attention_mask.to(torch.bool) pass From 1792deb7338a8475e70cd8fa6288f18da672ddba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:51:33 -0800 Subject: [PATCH 438/473] Update _utils.py --- unsloth/models/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 382024512..e1259af3a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -143,6 +143,11 @@ def filter(self, x): return not (self.text in x.getMessage()) transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed")) del transformers_training_args_logger +# No label_names provided for model class +from transformers.trainer import logger as transformers_trainer_logger +transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names")) +del transformers_trainer_logger + # Using the default loss: `ForCausalLMLoss`. try: from transformers.modeling_utils import logger as transformers_modeling_utils_logger From 5dcd079e61a414a3043bfb3d5b06738f63d11def Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:28:21 -0800 Subject: [PATCH 439/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 23b31172f..dd4d5a0e8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -165,6 +165,7 @@ def grpo_trainer__prepare_inputs(function_name, function): def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function + print(function) # .*? matches first match. .+? matches final match. 
replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" return " "*function.find("def") + replacement From ec6e0b7ac25e71e2e76f7cbcc1cc76df1a0cf5e4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:31:37 -0800 Subject: [PATCH 440/473] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index dd4d5a0e8..06ae82140 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -164,11 +164,11 @@ def grpo_trainer__prepare_inputs(function_name, function): # Remove _move_model_to_vllm def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function + + def _move_model_to_vllm(self, *args, **kwargs): return None - print(function) - # .*? matches first match. .+? matches final match. - replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" - return " "*function.find("def") + replacement + function = inspect.getsource(_move_model_to_vllm) + return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From bc1d2cefa9582fec5de3788daff13c9de6b20c07 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:43:46 -0800 Subject: [PATCH 441/473] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96aa0696f..e17fbfb32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -366,7 +366,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1", "xformers", "bitsandbytes>=0.46.1", From adbe38e6ca9c33826e073e196863d01ada762539 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 09:02:41 -0800 Subject: [PATCH 442/473] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e17fbfb32..14797c8fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.6", + "unsloth_zoo>=2025.2.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -348,7 +348,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.6", + "unsloth_zoo>=2025.2.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From a9b542fa8e9b0c3fbb204262cbe8972d87a303bf Mon Sep 17 00:00:00 2001 From: Jyotin Goel <120490013+gjyotin305@users.noreply.github.com> Date: Sat, 22 Feb 2025 16:07:01 +0530 Subject: [PATCH 443/473] Export Model to ollama.com (#1648) * Ollama Export Model to ollama.com Signed-off-by: Jyotin Goel * Check for model_name Signed-off-by: Jyotin Goel * subprocess use instead of requests | added check for ollama server Signed-off-by: Jyotin Goel * create_ollama_model Signed-off-by: Jyotin Goel * create_ollama_model | fix Signed-off-by: 
Jyotin Goel * Push to Ollama Signed-off-by: Jyotin Goel --------- Signed-off-by: Jyotin Goel --- unsloth/save.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/unsloth/save.py b/unsloth/save.py index eaddfa05c..6770d658c 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -17,6 +17,8 @@ from peft.tuners.lora import Linear4bit as Peft_Linear4bit from peft.tuners.lora import Linear as Peft_Linear from typing import Optional, Callable, Union, List +import sys +import requests import torch import os import shutil @@ -1613,6 +1615,112 @@ def create_ollama_modelfile(tokenizer, gguf_location): return modelfile pass +def create_ollama_model( + username: str, + model_name: str, + tag: str, + modelfile_path: str +): + try: + init_check = subprocess.run( + ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3 + ) + if init_check.returncode == 0: + print(init_check.stdout.strip()) + else: + print("Ollama Server is not Running") + except subprocess.TimeoutExpired: + return "Ollama Request Timeout" + + process = subprocess.Popen( + ['ollama', 'create', f'{username}/{model_name}:{tag}', '-f', f'{modelfile_path}'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in iter(process.stdout.readline, ''): + print(line, end='') + sys.stdout.flush() + + return_code = process.wait() + + if return_code != 0: + print(f"\nMODEL CREATED FAILED WITH RETURN CODE {return_code}") + else: + print("\nMODEL CREATED SUCCESSFULLY") +pass + + +def push_to_ollama_hub(username: str, model_name: str, tag: str): + try: + init_check = subprocess.run( + ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3 + ) + if init_check.returncode == 0: + print(init_check.stdout.strip()) + else: + print("Ollama Server is not Running") + except subprocess.TimeoutExpired: + return "Ollama Request Timeout" + + process = subprocess.Popen( + ['ollama', 'push', f'{username}/{model_name}:{tag}'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in iter(process.stdout.readline, ''): + print(line, end='') + sys.stdout.flush() + + return_code = process.wait() + + if return_code != 0: + print(f"\nMODEL PUBLISHED FAILED WITH RETURN CODE {return_code}") + else: + print("\nMODEL PUBLISHED SUCCESSFULLY") + + +def push_to_ollama( + tokenizer, + gguf_location, + username: str, + model_name: str, + tag: str +): + model_file = create_ollama_modelfile( + tokenizer=tokenizer, + gguf_location=gguf_location + ) + + with open(f"Modelfile_{model_name}", "w") as f: + f.write(model_file) + f.close() + + create_ollama_model( + username=username, + model_name=model_name, + tag=tag, + modelfile_path=f"Modelfile_{model_name}" + ) + + push_to_ollama_hub( + username=username, + model_name=model_name, + tag=tag + ) + + print("Succesfully pushed to ollama") + + + + def unsloth_save_pretrained_gguf( self, From 9cab34721ce70481180377b2e12656f2a7128c62 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:08:44 -0800 Subject: [PATCH 444/473] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index fcba2eb6d..1c9998e1c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -279,10 +279,11 @@ def forward(ctx, 
logits, labels, logit_softcapping : float = 0, logit_scaling : n_rows : int vocab_size : int n_rows, vocab_size = logits.shape + device = logits.device div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks : int = div + (mod != 0) - losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + losses = torch.empty(n_rows, dtype = torch.float32, device = device) DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) @@ -292,7 +293,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) - logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device) _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), @@ -309,7 +310,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : ) else: # For large vocabs > 65336 like Gemma 256K - logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device) _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( logits, logits.stride(0), From 0ae908247ec45f15ee12959af7d5fa33a0731eb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:31:16 -0800 Subject: [PATCH 445/473] torch_cuda_device --- unsloth/kernels/cross_entropy_loss.py | 91 +++++++++++++++------------ unsloth/kernels/geglu.py | 18 ++++-- unsloth/kernels/layernorm.py | 48 +++++++------- unsloth/kernels/rms_layernorm.py | 47 +++++++------- unsloth/kernels/rope_embedding.py | 42 +++++++------ unsloth/kernels/swiglu.py | 8 ++- unsloth/kernels/utils.py | 1 + 7 files changed, 140 insertions(+), 115 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 1c9998e1c..006dfff63 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -15,7 +15,13 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast +from .utils import ( + calculate_settings, + MAX_FUSED_SIZE, + triton_tanh, + triton_cast, + torch_cuda_device, +) from transformers.models.llama.modeling_llama import logger from packaging.version import Version @@ -295,37 +301,39 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : BLOCK_SIZE, num_warps = calculate_settings(vocab_size) logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device) - _cross_entropy_forward[(n_rows,)]( - logits, logits.stride(0), - losses, - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, - LOGIT_SCALE = logit_scaling, - num_warps = num_warps, - ) + with torch_cuda_device(device): + _cross_entropy_forward[(n_rows,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = num_warps, + ) else: # For large vocabs > 65336 like Gemma 256K logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device) - _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( - logits, 
logits.stride(0), - losses, - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, - LOGIT_SCALE = logit_scaling, - num_warps = 32, - ) + with torch_cuda_device(device): + _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = 32, + ) # logsumexp(chunked_logsumexp) - x # Do the -x separately logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum @@ -355,19 +363,20 @@ def backward(ctx, dlosses): div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks : int = div + (mod != 0) - _cross_entropy_backward[(n_rows, n_blocks,)]( - logits, logits.stride(0), - dlosses, dlosses.stride(0), - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, - SOFTCAP = ctx.logit_softcapping, - DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, - LOGIT_SCALE = ctx.logit_scaling, - num_warps = 8, - ) + with torch_cuda_device(dlosses.device): + _cross_entropy_backward[(n_rows, n_blocks,)]( + logits, logits.stride(0), + dlosses, dlosses.stride(0), + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + SOFTCAP = ctx.logit_softcapping, + DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, + LOGIT_SCALE = ctx.logit_scaling, + num_warps = 8, + ) return logits, None, None, None, pass pass diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index 9fedae769..d5a69aa67 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -15,7 +15,11 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, triton_tanh +from .utils import ( + calculate_settings, + triton_tanh, + torch_cuda_device, +) @triton.jit @@ -43,7 +47,8 @@ def geglu_exact_forward_kernel(gate, up): n_elements = gate.numel() out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(gate.device): + _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -99,7 +104,8 @@ def geglu_exact_backward_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass @@ -135,7 +141,8 @@ def geglu_approx_forward_kernel(gate, up): n_elements = gate.numel() out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(gate.device): + _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -198,6 +205,7 @@ def geglu_approx_backward_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - 
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py index ffcc5cc13..26a77f03a 100644 --- a/unsloth/kernels/layernorm.py +++ b/unsloth/kernels/layernorm.py @@ -16,7 +16,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device from unsloth_zoo.patching_utils import ( patch_layernorm, ) @@ -111,17 +111,18 @@ def forward(ctx, X, W, b, eps): r = torch.empty(n_rows, dtype = torch.float32, device = device) mu = torch.empty(n_rows, dtype = torch.float32, device = device) - layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, - b, - r, - mu, - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(device): + layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, + b, + r, + mu, + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -137,17 +138,18 @@ def backward(ctx, dY): X, W, b, r, mu = ctx.saved_tensors n_rows, n_cols = dY.shape - layernorm_backward[(n_rows,)]( - dY, dY.stride(0), - X, X .stride(0), - W, - b, - r, - mu, - n_cols, ctx.eps, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + X, X .stride(0), + W, + b, + r, + mu, + n_cols, ctx.eps, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dX = dY.view(*shape) return dX, None, None, None, None pass diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7487c10ee..1cde6388e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -15,8 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings - +from .utils import calculate_settings, torch_cuda_device @triton.jit def _rms_layernorm_forward( @@ -154,15 +153,16 @@ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = r = torch.empty(n_rows, dtype = torch.float32, device = device) fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward - fx[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(device): + fx[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -183,18 +183,19 @@ def backward(ctx, dY : torch.Tensor): # dW = X dX = torch.empty_like(dY) if ctx.GEMMA else dY - _rms_layernorm_backward[(n_rows,)]( - dY, dY.stride(0), - dX, dX.stride(0), - X, X .stride(0), - W, W .stride(0), - r, r .stride(0), - # dW, dW.stride(0), - n_cols, ctx.eps, - GEMMA = ctx.GEMMA, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + _rms_layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + dX, dX.stride(0), + X, X .stride(0), + W, W .stride(0), + r, r .stride(0), + # dW, dW.stride(0), + n_cols, ctx.eps, + GEMMA = ctx.GEMMA, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dX = dX.view(*shape) return dX, None, 
None, None pass diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 88b9ccadb..a14a48535 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device ROPE_GROUP_SIZE : int = 4 def _rope_embedding( @@ -100,16 +100,17 @@ def forward(ctx, Q, cos, sin): div, mod = divmod(n_heads, ROPE_GROUP_SIZE) n_groups : int = div + (mod != 0) - _rope_embedding[(n_rows, n_groups, )]( - Q, Q.stride(0), - cos, cos.stride(0), - sin, sin.stride(0), - seq_len, - head_dim, n_heads, - BACKWARD_PASS = False, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(Q.device): + _rope_embedding[(n_rows, n_groups, )]( + Q, Q.stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, + head_dim, n_heads, + BACKWARD_PASS = False, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.n_groups = n_groups @@ -134,15 +135,16 @@ def backward(ctx, dY): cos = ctx.cos sin = ctx.sin - _rope_embedding[(n_rows, ctx.n_groups, )]( - dY, dY .stride(0), - cos, cos.stride(0), - sin, sin.stride(0), - seq_len, head_dim, n_heads, - BACKWARD_PASS = True, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + _rope_embedding[(n_rows, ctx.n_groups, )]( + dY, dY .stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, head_dim, n_heads, + BACKWARD_PASS = True, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dY = dY.view(batch, seq_len, n_heads, head_dim) return dY, None, None, pass diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py index 688e9f9a4..12f1f5e06 100644 --- a/unsloth/kernels/swiglu.py +++ b/unsloth/kernels/swiglu.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device @triton.jit @@ -43,7 +43,8 @@ def swiglu_fg_kernel(e, g): n_elements = e.numel() h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) return h pass @@ -94,6 +95,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 985adaaa4..4439a47f2 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -27,6 +27,7 @@ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") pass +torch_cuda_device = torch.cuda.device # tl.math.tanh now is libdevice.tanh From f21314c1c096f742f1b1b38ffefba9b9d299c50c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:38:32 -0800 Subject: [PATCH 446/473] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 
4439a47f2..7cd51e9ff 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -139,6 +139,7 @@ def get_lora_parameters_bias(proj): if HAS_CUDA_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): + use_global_buffer = False if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 9215212724896f9073b22e07c7d56dc13706505c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:41:35 -0800 Subject: [PATCH 447/473] Update utils.py --- unsloth/kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 7cd51e9ff..1d4b494dd 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -451,7 +451,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) if X.dim() == 3: batch, seq_len, d = X.shape @@ -461,6 +461,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass + print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W From 9d95aeee8d4db1b05bc629188367d3a21362cbdd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:43:02 -0800 Subject: [PATCH 448/473] Update utils.py --- unsloth/kernels/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 1d4b494dd..eb3a2e38c 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -460,8 +460,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - - print(X.device, W.device, torch.cuda.current_device()) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W From 35e9144a015f4cbe8a847a91e43ea277c3c86c21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:58:17 -0800 Subject: [PATCH 449/473] device --- unsloth/kernels/geglu.py | 10 ++++++---- unsloth/kernels/utils.py | 4 +++- unsloth/models/llama.py | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index d5a69aa67..1ece87c08 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -45,9 +45,10 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_exact_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") + device = gate.device + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - with torch_cuda_device(gate.device): + with torch_cuda_device(device): _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -139,9 +140,10 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_approx_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") + device = gate.device + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - with 
torch_cuda_device(gate.device): + with torch_cuda_device(device): _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index eb3a2e38c..2c4edf334 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -460,7 +460,9 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - + + if X.device != W.device: + print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe0627f8d..7f475869c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -385,6 +385,7 @@ def LlamaAttention_fast_forward( head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) + print(hidden_states.device, torch.cuda.current_device()) Q, K, V = self.apply_qkv(self, hidden_states) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 30b6f9449c0ad38bbd99e00a5bb7f45fd9981b02 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 00:04:08 -0800 Subject: [PATCH 450/473] device --- unsloth/kernels/utils.py | 7 +++---- unsloth/models/llama.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 2c4edf334..6bb44fbd1 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -452,7 +452,9 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) - + if X.device != W.device: + print(X.device, W.device, torch.cuda.current_device()) + if X.dim() == 3: batch, seq_len, d = X.shape X = X.view(-1, X.shape[-1]) @@ -460,9 +462,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - - if X.device != W.device: - print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7f475869c..fe0627f8d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -385,7 +385,6 @@ def LlamaAttention_fast_forward( head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) - print(hidden_states.device, torch.cuda.current_device()) Q, K, V = self.apply_qkv(self, hidden_states) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 64e2b00975520c9524d1511e31a1d3c58feef417 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 02:30:53 -0800 Subject: [PATCH 451/473] Update loader.py --- unsloth/models/loader.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 186545cf0..30128cd13 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -59,7 +59,15 @@ from .gemma2 import FastGemma2Model pass import torch - +from ._utils import ( + patch_compiling_bitsandbytes, + patch_model_and_tokenizer, + prepare_model_for_kbit_training, + patch_unsloth_smart_gradient_checkpointing, + patch_compiled_autograd, + process_vision_info, + unsloth_compile_transformers, +) class FastLanguageModel(FastLlamaModel): @staticmethod @@ -87,6 +95,10 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() 
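The recurring change in these kernel files is to allocate outputs on the input tensor's device instead of a hard-coded "cuda:0" and to launch every Triton kernel under a torch.cuda.device context, so multi-GPU runs dispatch work to the GPU that owns the tensors. A self-contained sketch of that launch pattern, using a hypothetical _scale_kernel rather than any kernel from this repository:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask = mask)
    tl.store(out_ptr + offsets, x * scale, mask = mask)

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)          # lives on x.device, never a hard-coded "cuda:0"
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch.cuda.device(x.device):  # make the launch target x's GPU
        _scale_kernel[grid](x, out, n_elements, s, BLOCK_SIZE = 1024)
    return out
```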
+ assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) + + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) if fast_inference: if importlib.util.find_spec("vllm") is None: @@ -367,15 +379,6 @@ def from_pretrained( pass -from ._utils import ( - patch_compiling_bitsandbytes, - patch_model_and_tokenizer, - prepare_model_for_kbit_training, - patch_unsloth_smart_gradient_checkpointing, - patch_compiled_autograd, - process_vision_info, - unsloth_compile_transformers, -) from ..kernels import ( patch_loss_functions, post_patch_loss_function, @@ -404,6 +407,7 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() + assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) patch_compiled_autograd() patch_compiling_bitsandbytes() From ffa327862b6f87cabcc1d9ebaa02b4f18eeb941e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 02:36:16 -0800 Subject: [PATCH 452/473] Update llama.py --- unsloth/models/llama.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe0627f8d..707091990 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -18,6 +18,7 @@ from functools import partial from typing import Optional, Tuple, List, Union from ._utils import * +from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version @@ -850,27 +851,14 @@ def LlamaModel_fast_forward( mask = self. GA_mask if use_static_mask else dynamic_GA_mask pass - if offloaded_gradient_checkpointing: - hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply( - decoder_layer, - hidden_states, - mask, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - None, - position_embeddings, - )[0] - - elif gradient_checkpointing: + if gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass + print(torch.utils.checkpoint.checkpoint) layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, @@ -2034,6 +2022,9 @@ def get_peft_model( ): transformers_set_seed(random_state) + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = model.get_input_embeddings().weight.dtype) + if type(r) is not int: raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.") if r <= 0: From 748c5b522d37c71bc068f3a56fba4d51205e7fe2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 14:58:30 -0800 Subject: [PATCH 453/473] Update README.md --- README.md | 62 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 5b2dd6f12..5e4add0a3 100644 --- a/README.md +++ b/README.md @@ -242,10 +242,8 @@ For **advanced installation instructions** or if you see weird errors during ins ```python from unsloth import FastLanguageModel -from unsloth import is_bfloat16_supported import torch -from trl import SFTTrainer -from transformers import TrainingArguments +from trl import SFTTrainer, SFTConfig from datasets import load_dataset max_seq_length = 2048 # Supports RoPE Scaling 
interally, so choose any! # Get LAION dataset @@ -254,21 +252,28 @@ dataset = load_dataset("json", data_files = {"train" : url}, split = "train") # 4bit pre quantized models we support for 4x faster downloading + no OOMs. fourbit_models = [ - "unsloth/mistral-7b-v0.3-bnb-4bit", # New Mistral v3 2x faster! + "unsloth/Meta-Llama-3.1-8B-bnb-4bit", # Llama-3.1 2x faster + "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", + "unsloth/Meta-Llama-3.1-70B-bnb-4bit", + "unsloth/Meta-Llama-3.1-405B-bnb-4bit", # 4bit for 405b! + "unsloth/Mistral-Small-Instruct-2409", # Mistral 22b 2x faster! "unsloth/mistral-7b-instruct-v0.3-bnb-4bit", - "unsloth/llama-3-8b-bnb-4bit", # Llama-3 15 trillion tokens model 2x faster! - "unsloth/llama-3-8b-Instruct-bnb-4bit", - "unsloth/llama-3-70b-bnb-4bit", - "unsloth/Phi-3-mini-4k-instruct", # Phi-3 2x faster! + "unsloth/Phi-3.5-mini-instruct", # Phi-3.5 2x faster! "unsloth/Phi-3-medium-4k-instruct", - "unsloth/mistral-7b-bnb-4bit", - "unsloth/gemma-7b-bnb-4bit", # Gemma 2.2x faster! + "unsloth/gemma-2-9b-bnb-4bit", + "unsloth/gemma-2-27b-bnb-4bit", # Gemma 2x faster! + + "unsloth/Llama-3.2-1B-bnb-4bit", # NEW! Llama 3.2 models + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", + "unsloth/Llama-3.2-3B-bnb-4bit", + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", + + "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B! ] # More models at https://huggingface.co/unsloth model, tokenizer = FastLanguageModel.from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", + model_name = "unsloth/Llama-3.2-1B", max_seq_length = max_seq_length, - dtype = None, load_in_4bit = True, ) @@ -292,16 +297,14 @@ model = FastLanguageModel.get_peft_model( trainer = SFTTrainer( model = model, train_dataset = dataset, - dataset_text_field = "text", - max_seq_length = max_seq_length, tokenizer = tokenizer, - args = TrainingArguments( + args = SFTConfig( + dataset_text_field = "text", + max_seq_length = max_seq_length, per_device_train_batch_size = 2, gradient_accumulation_steps = 4, warmup_steps = 10, max_steps = 60, - fp16 = not is_bfloat16_supported(), - bf16 = is_bfloat16_supported(), logging_steps = 1, output_dir = "outputs", optim = "adamw_8bit", @@ -333,17 +336,14 @@ RL including DPO, GRPO, PPO, Reward Modelling, Online DPO all work with Unsloth. 
import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID -from unsloth import FastLanguageModel, PatchDPOTrainer -from unsloth import is_bfloat16_supported -PatchDPOTrainer() +from unsloth import FastLanguageModel import torch -from transformers import TrainingArguments -from trl import DPOTrainer +from trl import DPOTrainer, DPOConfig +max_seq_length = 2048 model, tokenizer = FastLanguageModel.from_pretrained( model_name = "unsloth/zephyr-sft-bnb-4bit", max_seq_length = max_seq_length, - dtype = None, load_in_4bit = True, ) @@ -365,24 +365,22 @@ model = FastLanguageModel.get_peft_model( dpo_trainer = DPOTrainer( model = model, ref_model = None, - args = TrainingArguments( + train_dataset = YOUR_DATASET_HERE, + # eval_dataset = YOUR_DATASET_HERE, + tokenizer = tokenizer, + args = DPOConfig( per_device_train_batch_size = 4, gradient_accumulation_steps = 8, warmup_ratio = 0.1, num_train_epochs = 3, - fp16 = not is_bfloat16_supported(), - bf16 = is_bfloat16_supported(), logging_steps = 1, optim = "adamw_8bit", seed = 42, output_dir = "outputs", + max_length = 1024, + max_prompt_length = 512, + beta = 0.1, ), - beta = 0.1, - train_dataset = YOUR_DATASET_HERE, - # eval_dataset = YOUR_DATASET_HERE, - tokenizer = tokenizer, - max_length = 1024, - max_prompt_length = 512, ) dpo_trainer.train() ``` From 469ed48cf4b38cc14570ae70dc0927b456f4164e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:48:55 -0800 Subject: [PATCH 454/473] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 707091990..233f104ec 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -857,8 +857,7 @@ def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass - - print(torch.utils.checkpoint.checkpoint) + layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, From bc87afde4113b3b183773cb17767eed10c61bf3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:49:04 -0800 Subject: [PATCH 455/473] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 233f104ec..c7e630d42 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -857,7 +857,6 @@ def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass - layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, From ee9d6e5955d7ad919a3710c4939a4e335c37812e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:12:56 -0800 Subject: [PATCH 456/473] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cca77bb60..0f0d4c159 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -755,7 +755,8 @@ def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_tempora filename = os.path.join(file_location, f"{name}.pt") W = W.weight if hasattr(W, "weight") else W torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,) - offloaded_W = torch.load(filename, map_location = 
"cpu", mmap = True) + # We must use weights_only = False due to pickling + offloaded_W = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False) offloaded_W._offloaded_file_location = filename return offloaded_W pass From 91458bbcdcd582f38bb71376d71fd6f8e56a6b00 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:17:25 -0800 Subject: [PATCH 457/473] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 6bb44fbd1..e699e632f 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -452,6 +452,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) + print(W) if X.device != W.device: print(X.device, W.device, torch.cuda.current_device()) From a7a5d75b830355c3b1583c58b5b0da79773ee850 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:27:59 -0800 Subject: [PATCH 458/473] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index e699e632f..427c2233c 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -140,6 +140,7 @@ def get_lora_parameters_bias(proj): @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): use_global_buffer = False + print(W, quant_state) if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From d93cca24a8a8e0dcc09712267a5886a35e481ec4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:29:35 -0800 Subject: [PATCH 459/473] Update utils.py --- unsloth/kernels/utils.py | 51 +++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 427c2233c..3dd2d8e40 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -93,27 +93,29 @@ def calculate_settings(n : int) -> (int, int,): cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 - -def QUANT_STATE(W): - return getattr(W, "quant_state", None) -pass - +def QUANT_STATE(W): return getattr(W, "quant_state", None) def get_lora_parameters(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: - return W, QUANT_STATE(W), None, None, None + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if getattr(proj, "disable_adapters", True) or proj.merged: + return W, getattr(W, "quant_state", None), None, None, None pass - active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter - A = proj.lora_A [active_adapter].weight - B = proj.lora_B [active_adapter].weight - s = proj.scaling[active_adapter] - return W, QUANT_STATE(W), A, B, s + adapter = getattr(proj, "active_adapters", None) + if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) + adapter = adapter[0] + + return ( + W, + getattr(W, "quant_state", None), + proj.lora_A 
[adapter].weight, + proj.lora_B [adapter].weight, + proj.scaling[adapter], + ) pass @@ -121,19 +123,24 @@ def get_lora_parameters_bias(proj): # For DPO or disabled adapters base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - bias = base_layer.bias # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: - return W, QUANT_STATE(W), None, None, None, bias + return W, getattr(W, "quant_state", None), None, None, None, bias pass - active_adapter = proj.active_adapters[0] if \ - getattr(proj, "active_adapters", ) else proj.active_adapter - A = proj.lora_A [active_adapter].weight - B = proj.lora_B [active_adapter].weight - s = proj.scaling[active_adapter] - return W, QUANT_STATE(W), A, B, s, bias + adapter = getattr(proj, "active_adapters", None) + if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) + adapter = adapter[0] + + return ( + W, + getattr(W, "quant_state", None), + proj.lora_A [adapter].weight, + proj.lora_B [adapter].weight, + proj.scaling[adapter], + base_layer.bias, + ) pass if HAS_CUDA_STREAM: From 6e2a3a8d772b9b3c26fbd39c441b63d8689a158e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:33:18 -0800 Subject: [PATCH 460/473] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 3dd2d8e40..5b7be9a5f 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -147,7 +147,6 @@ def get_lora_parameters_bias(proj): @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): use_global_buffer = False - print(W, quant_state) if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 8f9ba99b76c519d4b6680b0edc93311b90d7b8ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:46:16 -0800 Subject: [PATCH 461/473] Update utils.py --- unsloth/kernels/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5b7be9a5f..5bb0e337d 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -459,9 +459,6 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) - print(W) - if X.device != W.device: - print(X.device, W.device, torch.cuda.current_device()) if X.dim() == 3: batch, seq_len, d = X.shape From ed697da94535beb23f34bce147d77c02059cfd77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:00:26 -0800 Subject: [PATCH 462/473] Update llama.py --- unsloth/models/llama.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c7e630d42..475f82a5b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -759,14 +759,9 @@ def LlamaModel_fast_forward( # Check checkpointing method gradient_checkpointing = False - offloaded_gradient_checkpointing = False if (self.gradient_checkpointing and self.training and not use_cache): - gradient_checkpointing = True - - if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"): - offloaded_gradient_checkpointing = True pass # Gemma2 has alternating SWA and global attn @@ -1975,9 +1970,14 @@ def 
from_pretrained( internal_model = model while hasattr(internal_model, "model"): internal_model._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True + internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True # For transformers > 4.47.1, we need to add rotary_emb to all attention layers if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"): @@ -2387,11 +2387,15 @@ def get_peft_model( if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True internal_model = internal_model.model pass if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True # Clear deleted GPU items for _ in range(3): From d73c34bf19917945f6c5166cdb309eee8966b290 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:32:02 -0800 Subject: [PATCH 463/473] Update llama.py --- unsloth/models/llama.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 475f82a5b..b5bfa3cbf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1684,10 +1684,10 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ - f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ - f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' + f' "-____-" Free license: http://github.com/unslothai/unsloth' print(statistics) # Warn about fast transfers @@ -1879,11 +1879,11 @@ def from_pretrained( # Cannot use \\ since it will cause a SyntaxWarning in Python 3.12 # Instead use chr(92) == \\ debug_info = """debug_info = \\ - f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\ - f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ - f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ - f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}' + f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ + f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ + f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size = {total_train_batch_size:,}\\n"\\ + f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model)}' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 4485da745ba2728396815f7edbd548832ffd633e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:41:37 -0800 Subject: [PATCH 464/473] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b5bfa3cbf..7bee733a1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1882,8 +1882,8 @@ def from_pretrained( f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size = {total_train_batch_size:,}\\n"\\ - f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model)}' + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f})' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 45ea48c3ce2e252bf6de790ad05a7db55a4acc9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:58:58 -0800 Subject: [PATCH 465/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py 
b/unsloth/models/llama.py index 7bee733a1..6bff0f217 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1883,7 +1883,7 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f})' + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 8c4b79c32df8a706bed707f12426220b366a6541 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:59:11 -0800 Subject: [PATCH 466/473] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6bff0f217..bcabbd512 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1882,7 +1882,7 @@ def from_pretrained( f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import subprocess, re, gc From c2ae5101e8fa8daa4e4de2ac5755740196f8c05d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:28:35 -0800 Subject: [PATCH 467/473] Update utils.py --- unsloth/kernels/utils.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5bb0e337d..f42ceeca2 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -19,6 +19,7 @@ # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch +torch_Tensor = torch.Tensor from packaging.version import Version if Version(torch.__version__) < Version("2.4.0"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd @@ -68,6 +69,18 @@ def calculate_settings(n : int) -> (int, int,): HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") get_ptr = bnb.functional.get_ptr +if torch.cuda.device_count() > 1: + 
def _cuda_device_of(a: torch_Tensor): return torch.cuda.device_of(a) +else: + from contextlib import nullcontext + def _cuda_device_of(a: torch_Tensor): return nullcontext() +pass +_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream +c_void_p = ctypes.c_void_p +def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: + return c_void_p(_cuda_getCurrentRawStream(tensor.device.index)) +pass + # Get array of CUDA streams and other buffers global CUDA_STREAMS global WEIGHT_BUFFERS @@ -202,18 +215,19 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) - cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM, - ) - out_absmax += offset - - # Dequantize W - fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ - cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) - + with _cuda_device_of(absmax): + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), _get_tensor_stream(absmax), + ) + out_absmax += offset + + # Dequantize W + fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ + cdequantize_blockwise_bf16_nf4 + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), _get_tensor_stream(absmax),) + pass # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) return out.t() if is_transposed else out From 432ea2447f532691ec11148d9aabf63b2bb65d21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:35:19 -0800 Subject: [PATCH 468/473] Update utils.py --- unsloth/kernels/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f42ceeca2..7a6927471 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -28,7 +28,6 @@ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") pass -torch_cuda_device = torch.cuda.device # tl.math.tanh now is libdevice.tanh @@ -70,10 +69,10 @@ def calculate_settings(n : int) -> (int, int,): get_ptr = bnb.functional.get_ptr if torch.cuda.device_count() > 1: - def _cuda_device_of(a: torch_Tensor): return torch.cuda.device_of(a) + torch_cuda_device = torch.cuda.device else: from contextlib import nullcontext - def _cuda_device_of(a: torch_Tensor): return nullcontext() + def torch_cuda_device(device): return nullcontext() pass _cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream c_void_p = ctypes.c_void_p @@ -215,10 +214,10 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) - with _cuda_device_of(absmax): + with torch_cuda_device(device): cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), _get_tensor_stream(absmax), + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM ) out_absmax += offset @@ -226,7 +225,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False fx = 
cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes_c_int(blocksize), ctypes_c_int(out.numel()), _get_tensor_stream(absmax),) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) pass # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From dcff03c59a6cb5781409bb5fcdbb72a08847e51b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:37:09 -0800 Subject: [PATCH 469/473] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 7a6927471..fc45a2b4b 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -158,7 +158,6 @@ def get_lora_parameters_bias(proj): if HAS_CUDA_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): - use_global_buffer = False if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 6ef086694a14681f1ab40d7ff158c5d7d6f034a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:38:44 -0800 Subject: [PATCH 470/473] Update utils.py --- unsloth/kernels/utils.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index fc45a2b4b..273eddcc2 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -337,19 +337,21 @@ def fast_gemv(X, W, quant_state, out = None): ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = device) - cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, - ) - df += offset - absmax = df + with torch_cuda_device(device): + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, + ) + df += offset + absmax = df - fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ - cgemm_4bit_inference_naive_bf16 + fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ + cgemm_4bit_inference_naive_bf16 - blocksize = ctypes_c_int32(blocksize) - fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + blocksize = ctypes_c_int32(blocksize) + fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), + lda, ldb, ldc, blocksize, CUDA_STREAM,) + pass return out pass @@ -470,7 +472,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape From 8c8ce96af782b50ea485e90f0845c2447edc4a5c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:57:54 -0800 Subject: [PATCH 471/473] __version__ --- unsloth/__init__.py | 1 + unsloth/models/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index e33d16577..caa06b012 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -212,6 +212,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * +from .models import 
__version__ from .save import * from .chat_templates import * from .tokenizer_utils import * diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 29ad78dae..e11cd5441 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -19,5 +19,5 @@ from .mistral import FastMistralModel from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer -from ._utils import is_bfloat16_supported +from ._utils import is_bfloat16_supported, __version__ from .rl import PatchFastRL, vLLMSamplingParams From 208971bc3347723402db70e31cbfc904dee9ee67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 03:31:38 -0800 Subject: [PATCH 472/473] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8f346073b..3a9d651d1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -495,7 +495,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From adc697770f3c9f2878b0e7fc5e863ba9e3a8cfcc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 03:47:51 -0800 Subject: [PATCH 473/473] Bug fixes --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/kernels/utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de1583e9e..73e69dcd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] windows=[ - "unsloth_zoo>=2025.2.7", + "unsloth_zoo>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -61,7 +61,7 @@ windows=[ "xformers>=0.0.22.post7 ; platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.7", + "unsloth_zoo>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index caa06b012..c8f292698 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.2.6"): + if Version(unsloth_zoo_version) < Version("2025.3.1"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 273eddcc2..5eb9b8f5c 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -473,7 +473,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - + if X.dim() == 3: batch, seq_len, d = X.shape X = X.view(-1, X.shape[-1])
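For orientation, the data flow these utils.py changes serve can be summarised in plain PyTorch: take the frozen (possibly 4-bit) base weight together with the active adapter's A, B and scaling from a PEFT linear layer, dequantize the base weight, and add the low-rank correction. A rough reference sketch under standard PEFT shape conventions (A of shape (r, d), B of shape (out, r)), not the fused implementation itself:

```python
import torch

def lora_linear_reference(X, W_dequant, A, B, s):
    # X: (..., d); W_dequant: (d, out) after dequantizing and transposing the base weight
    out = X @ W_dequant                        # frozen base projection
    if A is not None:                          # adapters can be disabled or merged
        out = out + (X @ A.t()) @ B.t() * s    # low-rank LoRA correction
    return out
```

In the patched code the base weight comes from fast_dequantize on its quant_state immediately before this matmul, so the whole product runs in the working dtype on the tensor's own device.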