diff --git a/pyproject.toml b/pyproject.toml
index 40a92e6..256e5c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "axolotl-contribs-lgpl"
-version = "0.0.4.dev0"
+version = "0.0.4"
 description = "LGPL contributions to the axolotl framework"
 authors = [{ name="Wing Lian" }]
 readme = "README.md"
diff --git a/src/axolotl/contribs/lgpl/__init__.py b/src/axolotl/contribs/lgpl/__init__.py
index e69de29..244d1c1 100644
--- a/src/axolotl/contribs/lgpl/__init__.py
+++ b/src/axolotl/contribs/lgpl/__init__.py
@@ -0,0 +1 @@
+from .unsloth import fix_untrained_tokens  # noqa: F401
\ No newline at end of file
diff --git a/src/axolotl/contribs/lgpl/unsloth.py b/src/axolotl/contribs/lgpl/unsloth.py
index 9204516..e57a5f2 100644
--- a/src/axolotl/contribs/lgpl/unsloth.py
+++ b/src/axolotl/contribs/lgpl/unsloth.py
@@ -14,56 +14,68 @@ import itertools
 import logging
 from collections import Counter
+from contextlib import nullcontext
+from functools import partial
 
 import datasets
+import deepspeed
 import numpy as np
 import torch
 
 LOG = logging.getLogger("axolotl.contribs.lgpl.unsloth")
 
-@torch.inference_mode()
-def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
-    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
-):
+def get_embedding_mean(
+    input_embeddings, output_embeddings, tokenizer, train_dataset, eps=1e-16, token_ids_to_fix=None
+):
     """
-    Llama-3 for eg has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-    We reset them to the mean of the rest of the tokens
+    Detects tokens with untrained embeddings (values close to zero or duplicated) and
+    prepares scaled mean embedding values for these tokens based on their frequency in
+    the training data.
+
+    Args:
+        input_embeddings: The model's input embedding layer (embed_tokens)
+        output_embeddings: The model's output embedding layer (lm_head)
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset to analyze token frequencies
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process
+
+    Returns:
+        tuple: If untrained tokens are found in the training data, returns:
+            - mean_embedding_repeated: Scaled mean embeddings for the input layer
+            - mean_lm_head_repeated: Scaled mean embeddings for the output layer
+            - tokens_to_update: List of token IDs that need updating
+        None: If no untrained tokens are found in the training data
+
+    Raises:
+        ValueError: If embedding matrices have incorrect shapes, or if untrained
+            tokens are found but the embeddings are not trainable
     """
-    # Code licensed under LGPL
-    if not token_ids_to_fix:
-        token_ids_to_fix = []
-    embedding_matrix = model.get_input_embeddings().weight
-    lm_head_matrix = model.get_output_embeddings().weight
+
+    # Check if we still have an issue with the shapes
+    if input_embeddings.weight.shape[0] == 0 or output_embeddings.weight.shape[0] == 0:
+        raise ValueError(
+            f"Could not gather embedding matrices properly. "
+            f"Shapes: embedding={input_embeddings.weight.shape}, lm_head={output_embeddings.weight.shape}. "
+            f"This might indicate a DeepSpeed configuration issue."
+        )
+
+    # Get the chat template if available
     chat_template = getattr(tokenizer, "chat_template", None)
     tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
-    # Ignore some model checks for now
-    if not ignored_tokenizer_names:
-        ignored_tokenizer_names = []
-    if (
-        model.config._name_or_path  # pylint: disable=protected-access
-        in ignored_tokenizer_names
-    ):
-        return
 
     # Sometimes the sizes can be different like in vision models
     # Ie <image> is in input, but not in output
-    min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
-    embedding_matrix = embedding_matrix[:, :min_size]
-    lm_head_matrix = lm_head_matrix[:, :min_size]
+    min_size = min(input_embeddings.weight.shape[1], output_embeddings.weight.shape[1])
+    input_embeddings.weight = torch.nn.Parameter(input_embeddings.weight[:, :min_size])
+    output_embeddings.weight = torch.nn.Parameter(output_embeddings.weight[:, :min_size])
 
     # Get untrained tokens
-    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
-    # Check lm_head as well
+    indicator_untrained1 = torch.amax(input_embeddings.weight, axis=1) <= eps
 
     # Does NOT work for Llama 3.1!!
-    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+    indicator_untrained2 = torch.amax(output_embeddings.weight, axis=1) <= eps
 
     # We instead check for repeated vectors
     lm_head_where = torch.where(indicator_untrained1)[0]
-    lm_head_bad = lm_head_matrix[lm_head_where]
+    lm_head_bad = output_embeddings.weight[lm_head_where]
     lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
     counter = Counter()
     for row in lm_head_bad:
@@ -75,7 +87,8 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
     for j, row in enumerate(lm_head_bad):
         if hash(row.data.tobytes()) in counter:
             final_bad_lm_head.append(lm_head_where[j])
-    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
+
+    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained1, dtype=torch.bool)
     indicator_untrained2[final_bad_lm_head] = True
 
     # Combine both checks
@@ -89,7 +102,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     where_untrained = torch.where(indicator_untrained)[0]
     n_untrained = where_untrained.shape[0]
-    n_trained = embedding_matrix.shape[0] - n_untrained
+    n_trained = input_embeddings.weight.shape[0] - n_untrained
 
     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
@@ -145,9 +158,9 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     # Check if lm_head / embed_token are trainable!
     bad_not_trainable = False
-    if not embedding_matrix.requires_grad:
+    if not input_embeddings.weight.requires_grad:
         bad_not_trainable = True
-    if not lm_head_matrix.requires_grad:
+    if not output_embeddings.weight.requires_grad:
         bad_not_trainable = True
 
     if bad_not_trainable:  # pylint: disable=too-many-nested-blocks
@@ -176,8 +189,6 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     # If no bad tokens, possibly chat template itself has issues?
     if len(final_bad_items) == 0:
-        # Recheck 2000 and last 2000 items
-        size_dataset = len(train_dataset)
         size = min(size_dataset, 2000)
         for j in range(size):
             input_ids = train_dataset[j]
@@ -207,7 +218,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     # Count all the possible bad tokens
     final_counts = np.zeros(
-        max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+        max(len(tokenizer), input_embeddings.weight.shape[0]), dtype=np.int64
     )
 
     def mapping(examples):
@@ -235,15 +246,15 @@ def mapping(examples):
     )
 
     # Get sum of all items
-    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
-    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+    sum_embedding = torch.sum(input_embeddings.weight, dtype=torch.float32, axis=0)
+    sum_lm_head = torch.sum(output_embeddings.weight, dtype=torch.float32, axis=0)
 
     # Remove bad tokens
     sum_embedding -= torch.sum(
-        embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+        input_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )
     sum_lm_head -= torch.sum(
-        lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+        output_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )
 
     # Find correct average by dividing by sum of trained tokens
@@ -262,14 +273,60 @@ def mapping(examples):
         mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
     )
 
-    # Update embeddings only for tokens seen in train_dataset
-    embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
-        embedding_matrix.dtype
-    )
-    lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
+    return mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update
+
+def fix_untrained_tokens(
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None, is_ds_zero3=False
+):
+    """
+    Some base models have untrained vectors for embeddings/tokens. Update these embeddings to
+    the mean of the rest of the tokens. This additionally handles distributed settings like
+    DeepSpeed ZeRO-3.
+
+    Args:
+        model: The model to fix embeddings for
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset to analyze token frequencies
+        ignored_tokenizer_names: List of model names to skip processing (default: None)
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process (default: None)
+        is_ds_zero3: Whether DeepSpeed ZeRO-3 is being used (default: False)
+    """
+
+    if not token_ids_to_fix:
+        token_ids_to_fix = []
+    if not ignored_tokenizer_names:
+        ignored_tokenizer_names = []
+
+    # Check if we should ignore this model
+    if (
+        hasattr(model, "config") and
+        hasattr(model.config, "_name_or_path") and
+        model.config._name_or_path in ignored_tokenizer_names
+    ):
+        return
+
+    with torch.no_grad():
+        # Get the embedding layer and lm_head
+        embedding_layer = model.get_input_embeddings()
+        lm_head_layer = model.get_output_embeddings()
+
+        context = nullcontext
+        if is_ds_zero3:
+            # Gather the full (unsharded) parameters if using DeepSpeed ZeRO-3
+            context = partial(
+                deepspeed.zero.GatheredParameters,
+                [embedding_layer.weight, lm_head_layer.weight],
+                modifier_rank=0,
+            )
+
+        with context():
+            input_embeddings = model.get_input_embeddings()
+            output_embeddings = model.get_output_embeddings()
+            result = get_embedding_mean(
+                input_embeddings,
+                output_embeddings,
+                tokenizer,
+                train_dataset,
+                eps=eps,
+                token_ids_to_fix=token_ids_to_fix,
+            )
+            # get_embedding_mean returns None when no untrained tokens appear in the
+            # training data; skip the in-place update rather than unpacking None
+            if result is not None:
+                mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = result
+                input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
+                    input_embeddings.weight.dtype
+                )
+                model.set_input_embeddings(input_embeddings)
+                output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(
+                    output_embeddings.weight.dtype
+                )
+                model.set_output_embeddings(output_embeddings)
 
     # Clean up
     for _ in range(3):
        gc.collect()
        torch.cuda.empty_cache()
-    return
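
Reviewer note on the ZeRO-3 path: deepspeed.zero.GatheredParameters temporarily
materializes parameters that ZeRO-3 keeps sharded across ranks, and modifier_rank=0
broadcasts rank 0's in-place edits back into the partitions when the context exits.
A standalone sketch of that pattern, with a hypothetical scale_embeddings helper
used purely for illustration:

    import deepspeed
    import torch

    def scale_embeddings(model, factor):
        # Hypothetical helper: gather the sharded embedding weight on every rank,
        # edit it on rank 0, and let DeepSpeed re-partition the result on exit.
        weight = model.get_input_embeddings().weight
        with deepspeed.zero.GatheredParameters([weight], modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                weight.mul_(factor)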
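Usage sketch (illustrative, not part of the patch): how the new entry point might be
called from a training script. The model name, the toy dataset, and the token IDs
below are assumptions for illustration; axolotl wires this up internally.

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from axolotl.contribs.lgpl import fix_untrained_tokens

    # Any causal LM plus a tokenized dataset carrying "input_ids" (hypothetical setup)
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    train_dataset = Dataset.from_dict({"input_ids": [[128000, 9906, 128009]]})

    # Pass is_ds_zero3=True when training under DeepSpeed ZeRO-3 so the sharded
    # embedding and lm_head weights are gathered before inspection and update.
    fix_untrained_tokens(model, tokenizer, train_dataset, is_ds_zero3=False)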