1 change: 1 addition & 0 deletions src/axolotl/contribs/lgpl/__init__.py
@@ -0,0 +1 @@
from .unsloth import fix_untrained_tokens  # noqa: F401
122 changes: 76 additions & 46 deletions src/axolotl/contribs/lgpl/unsloth.py
@@ -16,54 +16,41 @@
from collections import Counter

import datasets
import deepspeed
import numpy as np
import torch

LOG = logging.getLogger("axolotl.contribs.lgpl.unsloth")


@torch.inference_mode()
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
):
"""
Llama-3 for eg has untrained vectors in the base model.
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
We reset them to the mean of the rest of the tokens
"""
# Code licensed under LGPL
if not token_ids_to_fix:
token_ids_to_fix = []
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
def get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=1e-16, token_ids_to_fix=None):
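# Shared helper: validates the gathered embedding/lm_head shapes, locates untrained
# token rows, and returns the mean-of-trained replacement rows together with the
# token ids (seen in train_dataset) that the caller should overwrite.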
# Check if we still have an issue with the shapes
if input_embeddings.weight.shape[0] == 0 or output_embeddings.weight.shape[0] == 0:
raise ValueError(
f"Could not gather embedding matrices properly. "
f"Shapes: embedding={input_embeddings.weight.shape}, lm_head={output_embeddings.weight.shape}. "
f"This might indicate a DeepSpeed configuration issue."
)

# Get the chat template if available
chat_template = getattr(tokenizer, "chat_template", None)
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer

# Ignore some model checks for now
if not ignored_tokenizer_names:
ignored_tokenizer_names = []
if (
model.config._name_or_path # pylint: disable=protected-access
in ignored_tokenizer_names
):
return

# Sometimes the sizes can be different like in vision models
# Ie <image> is in input, but not in output
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
embedding_matrix = embedding_matrix[:, :min_size]
lm_head_matrix = lm_head_matrix[:, :min_size]
min_size = min(input_embeddings.weight.shape[1], output_embeddings.weight.shape[1])
input_embeddings.weight = torch.nn.Parameter(input_embeddings.weight[:, :min_size])
output_embeddings.weight = torch.nn.Parameter(output_embeddings.weight[:, :min_size])

# Get untrained tokens
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
# Check lm_head as well
indicator_untrained1 = torch.amax(input_embeddings.weight, axis=1) <= eps

# Does NOT work for Llama 3.1!!
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
indicator_untrained2 = torch.amax(output_embeddings.weight, axis=1) <= eps

# We instead check for repeated vectors
lm_head_where = torch.where(indicator_untrained1)[0]
lm_head_bad = lm_head_matrix[lm_head_where]
lm_head_bad = output_embeddings.weight[lm_head_where]
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
counter = Counter()
for row in lm_head_bad:
@@ -75,7 +62,8 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
for j, row in enumerate(lm_head_bad):
if hash(row.data.tobytes()) in counter:
final_bad_lm_head.append(lm_head_where[j])
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)

indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained1, dtype=torch.bool)
indicator_untrained2[final_bad_lm_head] = True

# Combine both checks
@@ -89,7 +77,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
n_trained = input_embeddings.weight.shape[0] - n_untrained

# Get set and actual tokens
where_untrained = where_untrained.tolist()
@@ -145,9 +133,9 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# Check if lm_head / embed_token are trainable!
bad_not_trainable = False
if not embedding_matrix.requires_grad:
if not input_embeddings.weight.requires_grad:
bad_not_trainable = True
if not lm_head_matrix.requires_grad:
if not output_embeddings.weight.requires_grad:
bad_not_trainable = True

if bad_not_trainable: # pylint: disable=too-many-nested-blocks
@@ -176,8 +164,6 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# If no bad tokens, possibly chat template itself has issues?
if len(final_bad_items) == 0:
# Recheck 2000 and last 2000 items
size_dataset = len(train_dataset)
size = min(size_dataset, 2000)
for j in range(size):
input_ids = train_dataset[j]
@@ -207,7 +193,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

# Count all the possible bad tokens
final_counts = np.zeros(
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
max(len(tokenizer), input_embeddings.weight.shape[0]), dtype=np.int64
)

def mapping(examples):
@@ -235,15 +221,15 @@ def mapping(examples):
)

# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
sum_embedding = torch.sum(input_embeddings.weight, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(output_embeddings.weight, dtype=torch.float32, axis=0)

# Remove bad tokens
sum_embedding -= torch.sum(
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
input_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
)
sum_lm_head -= torch.sum(
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
output_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
)

# Find correct average by dividing by sum of trained tokens
@@ -262,11 +248,55 @@ def mapping(examples):
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)

# Update embeddings only for tokens seen in train_dataset
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
embedding_matrix.dtype
)
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
return mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update

def fix_untrained_tokens(
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None, is_ds_zero3=False
):
"""
Modified version of fix_untrained_tokens that works with distributed settings like DeepSpeed ZeRO-3.
Handles cases where embedding_matrix.shape might be [0] due to parameter sharding/offloading.
"""

if not token_ids_to_fix:
token_ids_to_fix = []
if not ignored_tokenizer_names:
ignored_tokenizer_names = []

# Check if we should ignore this model
if (
hasattr(model, "config") and
hasattr(model.config, "_name_or_path") and
model.config._name_or_path in ignored_tokenizer_names
):
return

with torch.no_grad():
# Get the embedding layer and lm_head
embedding_layer = model.get_input_embeddings()
lm_head_layer = model.get_output_embeddings()

# Get the full parameters if using DeepSpeed
if is_ds_zero3:
with deepspeed.zero.GatheredParameters([embedding_layer.weight, lm_head_layer.weight], modifier_rank=0):
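# GatheredParameters materializes the ZeRO-3 sharded weights in full on every rank;
# with modifier_rank=0, the updates made by rank 0 are written back to the shards
# when the context exits.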
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=eps, token_ids_to_fix=token_ids_to_fix)
input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
input_embeddings.weight.dtype
)
model.set_input_embeddings(input_embeddings)
output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(output_embeddings.weight.dtype)
model.set_output_embeddings(output_embeddings)
else:
# Fallback to original approach if not using DeepSpeed
input_embeddings = embedding_layer
output_embeddings = lm_head_layer
mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=eps, token_ids_to_fix=token_ids_to_fix)
input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
input_embeddings.weight.dtype
)
output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(output_embeddings.weight.dtype)

# Clean up
for _ in range(3):
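For context, a minimal sketch (not part of this PR) of how the updated entry point might be called from training setup code. The wrapper name `maybe_fix_untrained_tokens` and the way the parsed DeepSpeed config is passed in are illustrative assumptions; the `zero_optimization.stage` key is the standard DeepSpeed JSON field, and the import path relies on the re-export added in `__init__.py`.

from axolotl.contribs.lgpl import fix_untrained_tokens

def maybe_fix_untrained_tokens(model, tokenizer, train_dataset, ds_config=None):
    # ds_config would be the parsed DeepSpeed JSON config, if any (hypothetical plumbing).
    zero_cfg = (ds_config or {}).get("zero_optimization", {})
    is_ds_zero3 = zero_cfg.get("stage") == 3

    # With ZeRO-3 the embedding and lm_head weights are sharded across ranks,
    # so the helper gathers them before inspecting and patching untrained rows.
    fix_untrained_tokens(
        model,
        tokenizer,
        train_dataset,
        ignored_tokenizer_names=[],
        token_ids_to_fix=[],
        is_ds_zero3=is_ds_zero3,
    )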