2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "axolotl-contribs-lgpl"
version = "0.0.4.dev0"
version = "0.0.4"
description = "LGPL contributions to the axolotl framework"
authors = [{ name="Wing Lian" }]
readme = "README.md"
1 change: 1 addition & 0 deletions src/axolotl/contribs/lgpl/__init__.py
@@ -0,0 +1 @@
+from .unsloth import fix_untrained_tokens # noqa: F401
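This re-export makes the helper importable straight from the package root. A minimal usage sketch, assuming a Hugging Face causal LM and a small tokenized dataset (the model name and toy data below are illustrative, not part of this change):

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from axolotl.contribs.lgpl import fix_untrained_tokens

    # Illustrative base model; any causal LM with its matching tokenizer works the same way.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

    # Tokenized training data; rows are expected to carry "input_ids".
    train_dataset = Dataset.from_dict(
        {"input_ids": [tokenizer("example text")["input_ids"]]}
    )

    # Reset untrained embedding rows for tokens that actually appear in the training data.
    fix_untrained_tokens(model, tokenizer, train_dataset, is_ds_zero3=False)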
147 changes: 102 additions & 45 deletions src/axolotl/contribs/lgpl/unsloth.py
@@ -14,56 +14,68 @@
import itertools
import logging
from collections import Counter
+from contextlib import nullcontext
+from functools import partial

import datasets
+import deepspeed
import numpy as np
import torch

LOG = logging.getLogger("axolotl.contribs.lgpl.unsloth")


-@torch.inference_mode()
-def fix_untrained_tokens( # pylint: disable=too-many-return-statements
-model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
-):
+def get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=1e-16, token_ids_to_fix=None):
"""
-Llama-3 for eg has untrained vectors in the base model.
-These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-We reset them to the mean of the rest of the tokens
Detects tokens with untrained embeddings (values close to zero or duplicated) and
prepares scaled mean embedding values for these tokens based on their frequency in the training data.

Args:
input_embeddings: The model's input embedding layer (embed_tokens)
output_embeddings: The model's output embedding layer (lm_head)
tokenizer: The tokenizer used with the model
train_dataset: The training dataset to analyze token frequencies
eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
token_ids_to_fix: Additional token IDs to include in the fixing process

Returns:
tuple: If untrained tokens are found in the training data, returns:
- mean_embedding_repeated: Scaled mean embeddings for input layer
- mean_lm_head_repeated: Scaled mean embeddings for output layer
- tokens_to_update: List of token IDs that need updating
None: If no untrained tokens are found in the training data

Raises:
ValueError: If embedding matrices have incorrect shapes or if untrained tokens are found but embeddings are not trainable
"""
# Code licensed under LGPL
if not token_ids_to_fix:
token_ids_to_fix = []
-embedding_matrix = model.get_input_embeddings().weight
-lm_head_matrix = model.get_output_embeddings().weight

# Check if we still have an issue with the shapes
if input_embeddings.weight.shape[0] == 0 or output_embeddings.weight.shape[0] == 0:
raise ValueError(
f"Could not gather embedding matrices properly. "
f"Shapes: embedding={input_embeddings.weight.shape}, lm_head={output_embeddings.weight.shape}. "
f"This might indicate a DeepSpeed configuration issue."
)

# Get the chat template if available
chat_template = getattr(tokenizer, "chat_template", None)
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer

-# Ignore some model checks for now
-if not ignored_tokenizer_names:
-ignored_tokenizer_names = []
-if (
-model.config._name_or_path # pylint: disable=protected-access
-in ignored_tokenizer_names
-):
-return

# Sometimes the sizes can be different like in vision models
# Ie <image> is in input, but not in output
-min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
-embedding_matrix = embedding_matrix[:, :min_size]
-lm_head_matrix = lm_head_matrix[:, :min_size]
+min_size = min(input_embeddings.weight.shape[1], output_embeddings.weight.shape[1])
+input_embeddings.weight = torch.nn.Parameter(input_embeddings.weight[:, :min_size])
+output_embeddings.weight = torch.nn.Parameter(output_embeddings.weight[:, :min_size])

# Get untrained tokens
-indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
# Check lm_head as well
+indicator_untrained1 = torch.amax(input_embeddings.weight, axis=1) <= eps

# Does NOT work for Llama 3.1!!
-indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+indicator_untrained2 = torch.amax(output_embeddings.weight, axis=1) <= eps

# We instead check for repeated vectors
lm_head_where = torch.where(indicator_untrained1)[0]
-lm_head_bad = lm_head_matrix[lm_head_where]
+lm_head_bad = output_embeddings.weight[lm_head_where]
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
counter = Counter()
for row in lm_head_bad:
@@ -75,7 +87,8 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
for j, row in enumerate(lm_head_bad):
if hash(row.data.tobytes()) in counter:
final_bad_lm_head.append(lm_head_where[j])
-indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)

+indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained1, dtype=torch.bool)
indicator_untrained2[final_bad_lm_head] = True

# Combine both checks
@@ -89,7 +102,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
-n_trained = embedding_matrix.shape[0] - n_untrained
+n_trained = input_embeddings.weight.shape[0] - n_untrained

# Get set and actual tokens
where_untrained = where_untrained.tolist()
@@ -145,9 +158,9 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# Check if lm_head / embed_token are trainable!
bad_not_trainable = False
-if not embedding_matrix.requires_grad:
+if not input_embeddings.weight.requires_grad:
bad_not_trainable = True
-if not lm_head_matrix.requires_grad:
+if not output_embeddings.weight.requires_grad:
bad_not_trainable = True

if bad_not_trainable: # pylint: disable=too-many-nested-blocks
@@ -176,8 +189,6 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# If no bad tokens, possibly chat template itself has issues?
if len(final_bad_items) == 0:
# Recheck 2000 and last 2000 items
size_dataset = len(train_dataset)
size = min(size_dataset, 2000)
for j in range(size):
input_ids = train_dataset[j]
@@ -207,7 +218,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# Count all the possible bad tokens
final_counts = np.zeros(
-max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+max(len(tokenizer), input_embeddings.weight.shape[0]), dtype=np.int64
)

def mapping(examples):
@@ -235,15 +246,15 @@ def mapping(examples):
)

# Get sum of all items
-sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
-sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+sum_embedding = torch.sum(input_embeddings.weight, dtype=torch.float32, axis=0)
+sum_lm_head = torch.sum(output_embeddings.weight, dtype=torch.float32, axis=0)

# Remove bad tokens
sum_embedding -= torch.sum(
-embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+input_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
)
sum_lm_head -= torch.sum(
-lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+output_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
)

# Find correct average by dividing by sum of trained tokens
@@ -262,14 +273,60 @@ def mapping(examples):
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)

-# Update embeddings only for tokens seen in train_dataset
-embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
-embedding_matrix.dtype
-)
-lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
+return mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update

def fix_untrained_tokens(
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None, is_ds_zero3=False
):
"""
Some base models have untrained vectors for embeddings/tokens. Update these embeddings to
the mean of the rest of the tokens. This additionally handles distributed settings like DeepSpeed ZeRO-3.

Args:
model: The model to fix embeddings for
tokenizer: The tokenizer used with the model
train_dataset: The training dataset to analyze token frequencies
ignored_tokenizer_names: List of model names to skip processing (default: None)
eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
token_ids_to_fix: Additional token IDs to include in the fixing process (default: None)
is_ds_zero3: Whether DeepSpeed ZeRO-3 is being used (default: False)
"""

if not token_ids_to_fix:
token_ids_to_fix = []
if not ignored_tokenizer_names:
ignored_tokenizer_names = []

# Check if we should ignore this model
if (
hasattr(model, "config") and
hasattr(model.config, "_name_or_path") and
model.config._name_or_path in ignored_tokenizer_names
):
return

with torch.no_grad():
# Get the embedding layer and lm_head
embedding_layer = model.get_input_embeddings()
lm_head_layer = model.get_output_embeddings()

context = nullcontext
if is_ds_zero3:
# Get the full parameters if using DeepSpeed
context = partial(deepspeed.zero.GatheredParameters, [embedding_layer.weight, lm_head_layer.weight], modifier_rank=0)

with context():
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=eps, token_ids_to_fix=token_ids_to_fix)
input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
input_embeddings.weight.dtype
)
model.set_input_embeddings(input_embeddings)
output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(output_embeddings.weight.dtype)
model.set_output_embeddings(output_embeddings)

# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return
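For readers skimming the diff, the arithmetic behind get_embedding_mean comes down to averaging only the trained rows and then overwriting the detected rows with that mean (the full code additionally scales each replacement by the token's frequency in the training data). A standalone toy sketch, with made-up sizes and token ids:

    import torch

    # Toy tensors; sizes and token ids are made up for illustration.
    embedding = torch.randn(10, 4)
    where_untrained = [7, 9]          # rows detected as untrained
    embedding[where_untrained] = 0.0  # simulate near-zero untrained rows

    n_trained = embedding.shape[0] - len(where_untrained)

    # Mean over trained rows only: total sum minus the untrained rows.
    sum_embedding = torch.sum(embedding, dtype=torch.float32, axis=0)
    sum_embedding -= torch.sum(embedding[where_untrained], dtype=torch.float32, axis=0)
    mean_embedding = sum_embedding / n_trained

    # Overwrite only the untrained ids that actually occur in the training data.
    tokens_to_update = [7]
    embedding[tokens_to_update] = mean_embedding.to(embedding.dtype)

Under ZeRO-3 the same update runs inside deepspeed.zero.GatheredParameters with modifier_rank=0, so the full embedding and lm_head matrices are materialized on rank 0 before the rows are rewritten.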