From f7516bedf5a7f8fddca54e8823300d998f0792b2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 22 Dec 2024 13:29:54 -0500
Subject: [PATCH 1/2] allow manual token_ids to be passed to fix untrained lrs

---
 src/axolotl/contribs/lgpl/unsloth.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/contribs/lgpl/unsloth.py b/src/axolotl/contribs/lgpl/unsloth.py
index c61e8d0..332d5fa 100644
--- a/src/axolotl/contribs/lgpl/unsloth.py
+++ b/src/axolotl/contribs/lgpl/unsloth.py
@@ -24,7 +24,7 @@

 @torch.inference_mode()
 def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
-    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
 ):
     """
     Llama-3 for eg has untrained vectors in the base model.
@@ -32,6 +32,8 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
     We reset them to the mean of the rest of the tokens
     """
     # Code licensed under LGPL
+    if not token_ids_to_fix:
+        token_ids_to_fix = []
     embedding_matrix = model.get_input_embeddings().weight
     lm_head_matrix = model.get_output_embeddings().weight
     chat_template = getattr(tokenizer, "chat_template", None)
@@ -91,6 +93,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements

     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
+    where_untrained = list(set(where_untrained) + set(token_ids_to_fix))
     if len(where_untrained) == 0:
         return

From 2722306b40756b9cfafb21f62a28d41f5a65dd71 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 22 Dec 2024 18:02:12 -0500
Subject: [PATCH 2/2] fix set operation

---
 src/axolotl/contribs/lgpl/unsloth.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/contribs/lgpl/unsloth.py b/src/axolotl/contribs/lgpl/unsloth.py
index 332d5fa..9204516 100644
--- a/src/axolotl/contribs/lgpl/unsloth.py
+++ b/src/axolotl/contribs/lgpl/unsloth.py
@@ -93,7 +93,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements

     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
-    where_untrained = list(set(where_untrained) + set(token_ids_to_fix))
+    where_untrained = list(set(token_ids_to_fix + where_untrained))
     if len(where_untrained) == 0:
         return
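
Usage note (not part of the patches): a minimal sketch of how the new token_ids_to_fix argument could be called once these patches are applied. The model id, the chat-template special tokens, and the placeholder train_dataset below are illustrative assumptions, not values taken from the patch.

# Sketch only: assumes axolotl with these patches installed; the model id, the
# special tokens, and the tiny placeholder dataset are hypothetical examples.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.contribs.lgpl.unsloth import fix_untrained_tokens

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Stand-in for the tokenized training set that the trainer would normally pass.
train_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer("hello world")["input_ids"]]}
)

# Manually chosen token ids whose embeddings should be reset to the mean of the
# trained rows, e.g. chat-template specials the base model never trained on.
manual_ids = tokenizer.convert_tokens_to_ids(
    ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
)

fix_untrained_tokens(model, tokenizer, train_dataset, token_ids_to_fix=manual_ids)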