Commit edaec42

Merge pull request #4 from axolotl-ai-cloud/fix-untrained-distributed
handle distributed embeddings
2 parents cedbf68 + e76b897

3 files changed: +104 -46 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "axolotl-contribs-lgpl"
-version = "0.0.4.dev0"
+version = "0.0.4"
 description = "LGPL contributions to the axolotl framework"
 authors = [{ name="Wing Lian" }]
 readme = "README.md"
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .unsloth import fix_untrained_tokens  # noqa: F401

src/axolotl/contribs/lgpl/unsloth.py

Lines changed: 102 additions & 45 deletions

@@ -14,56 +14,68 @@
 import itertools
 import logging
 from collections import Counter
+from contextlib import nullcontext
+from functools import partial
 
 import datasets
+import deepspeed
 import numpy as np
 import torch
 
 LOG = logging.getLogger("axolotl.contribs.lgpl.unsloth")
 
 
-@torch.inference_mode()
-def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
-    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
-):
+def get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=1e-16, token_ids_to_fix=None):
     """
-    Llama-3 for eg has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-    We reset them to the mean of the rest of the tokens
+    Detects tokens with untrained embeddings (values close to zero or duplicated) and
+    prepares scaled mean embedding values for these tokens based on their frequency in the training data.
+
+    Args:
+        input_embeddings: The model's input embedding layer (embed_tokens)
+        output_embeddings: The model's output embedding layer (lm_head)
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset to analyze token frequencies
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process
+
+    Returns:
+        tuple: If untrained tokens are found in the training data, returns:
+            - mean_embedding_repeated: Scaled mean embeddings for input layer
+            - mean_lm_head_repeated: Scaled mean embeddings for output layer
+            - tokens_to_update: List of token IDs that need updating
+        None: If no untrained tokens are found in the training data
+
+    Raises:
+        ValueError: If embedding matrices have incorrect shapes or if untrained tokens are found but embeddings are not trainable
     """
-    # Code licensed under LGPL
-    if not token_ids_to_fix:
-        token_ids_to_fix = []
-    embedding_matrix = model.get_input_embeddings().weight
-    lm_head_matrix = model.get_output_embeddings().weight
+
+    # Check if we still have an issue with the shapes
+    if input_embeddings.weight.shape[0] == 0 or output_embeddings.weight.shape[0] == 0:
+        raise ValueError(
+            f"Could not gather embedding matrices properly. "
+            f"Shapes: embedding={input_embeddings.weight.shape}, lm_head={output_embeddings.weight.shape}. "
+            f"This might indicate a DeepSpeed configuration issue."
+        )
+
+    # Get the chat template if available
     chat_template = getattr(tokenizer, "chat_template", None)
     tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
 
-    # Ignore some model checks for now
-    if not ignored_tokenizer_names:
-        ignored_tokenizer_names = []
-    if (
-        model.config._name_or_path  # pylint: disable=protected-access
-        in ignored_tokenizer_names
-    ):
-        return
-
     # Sometimes the sizes can be different like in vision models
     # Ie <image> is in input, but not in output
-    min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
-    embedding_matrix = embedding_matrix[:, :min_size]
-    lm_head_matrix = lm_head_matrix[:, :min_size]
+    min_size = min(input_embeddings.weight.shape[1], output_embeddings.weight.shape[1])
+    input_embeddings.weight = torch.nn.Parameter(input_embeddings.weight[:, :min_size])
+    output_embeddings.weight = torch.nn.Parameter(output_embeddings.weight[:, :min_size])
 
     # Get untrained tokens
-    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
-    # Check lm_head as well
+    indicator_untrained1 = torch.amax(input_embeddings.weight, axis=1) <= eps
 
     # Does NOT work for Llama 3.1!!
-    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+    indicator_untrained2 = torch.amax(output_embeddings.weight, axis=1) <= eps
 
     # We instead check for repeated vectors
     lm_head_where = torch.where(indicator_untrained1)[0]
-    lm_head_bad = lm_head_matrix[lm_head_where]
+    lm_head_bad = output_embeddings.weight[lm_head_where]
     lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
     counter = Counter()
     for row in lm_head_bad:
@@ -75,7 +87,8 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
     for j, row in enumerate(lm_head_bad):
         if hash(row.data.tobytes()) in counter:
             final_bad_lm_head.append(lm_head_where[j])
-    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
+
+    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained1, dtype=torch.bool)
     indicator_untrained2[final_bad_lm_head] = True
 
     # Combine both checks
@@ -89,7 +102,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     where_untrained = torch.where(indicator_untrained)[0]
     n_untrained = where_untrained.shape[0]
-    n_trained = embedding_matrix.shape[0] - n_untrained
+    n_trained = input_embeddings.weight.shape[0] - n_untrained
 
     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
@@ -145,9 +158,9 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     # Check if lm_head / embed_token are trainable!
     bad_not_trainable = False
-    if not embedding_matrix.requires_grad:
+    if not input_embeddings.weight.requires_grad:
         bad_not_trainable = True
-    if not lm_head_matrix.requires_grad:
+    if not output_embeddings.weight.requires_grad:
         bad_not_trainable = True
 
     if bad_not_trainable:  # pylint: disable=too-many-nested-blocks
@@ -176,8 +189,6 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
         # If no bad tokens, possibly chat template itself has issues?
         if len(final_bad_items) == 0:
-            # Recheck 2000 and last 2000 items
-            size_dataset = len(train_dataset)
             size = min(size_dataset, 2000)
             for j in range(size):
                 input_ids = train_dataset[j]
@@ -207,7 +218,7 @@ def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
 
     # Count all the possible bad tokens
     final_counts = np.zeros(
-        max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+        max(len(tokenizer), input_embeddings.weight.shape[0]), dtype=np.int64
     )
 
     def mapping(examples):
@@ -235,15 +246,15 @@ def mapping(examples):
     )
 
     # Get sum of all items
-    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
-    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+    sum_embedding = torch.sum(input_embeddings.weight, dtype=torch.float32, axis=0)
+    sum_lm_head = torch.sum(output_embeddings.weight, dtype=torch.float32, axis=0)
 
     # Remove bad tokens
     sum_embedding -= torch.sum(
-        embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+        input_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )
     sum_lm_head -= torch.sum(
-        lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+        output_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )
 
     # Find correct average by dividing by sum of trained tokens
@@ -262,14 +273,60 @@ def mapping(examples):
         mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
     )
 
-    # Update embeddings only for tokens seen in train_dataset
-    embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
-        embedding_matrix.dtype
-    )
-    lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
+    return mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update
+
+def fix_untrained_tokens(
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None, is_ds_zero3=False
+):
+    """
+    Some base models have untrained vectors for embeddings/tokens. Update these embeddings to
+    the mean of the rest of the tokens. This additionally handles distributed settings like DeepSpeed ZeRO-3.
+
+    Args:
+        model: The model to fix embeddings for
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset to analyze token frequencies
+        ignored_tokenizer_names: List of model names to skip processing (default: None)
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process (default: None)
+        is_ds_zero3: Whether DeepSpeed ZeRO-3 is being used (default: False)
+    """
+
+    if not token_ids_to_fix:
+        token_ids_to_fix = []
+    if not ignored_tokenizer_names:
+        ignored_tokenizer_names = []
+
+    # Check if we should ignore this model
+    if (
+        hasattr(model, "config") and
+        hasattr(model.config, "_name_or_path") and
+        model.config._name_or_path in ignored_tokenizer_names
+    ):
+        return
+
+    with torch.no_grad():
+        # Get the embedding layer and lm_head
+        embedding_layer = model.get_input_embeddings()
+        lm_head_layer = model.get_output_embeddings()
+
+        context = nullcontext
+        if is_ds_zero3:
+            # Get the full parameters if using DeepSpeed
+            context = partial(deepspeed.zero.GatheredParameters, [embedding_layer.weight, lm_head_layer.weight], modifier_rank=0)
+
+        with context():
+            input_embeddings = model.get_input_embeddings()
+            output_embeddings = model.get_output_embeddings()
+            mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=eps, token_ids_to_fix=token_ids_to_fix)
+            input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
+                input_embeddings.weight.dtype
+            )
+            model.set_input_embeddings(input_embeddings)
+            output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(output_embeddings.weight.dtype)
+            model.set_output_embeddings(output_embeddings)
 
     # Clean up
     for _ in range(3):
         gc.collect()
         torch.cuda.empty_cache()
-    return
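To make the refactor easier to follow, here is a minimal, self-contained sketch of the detection logic that get_embedding_mean applies to the gathered weights: a token row counts as untrained when its maximum value is at most eps, and flagged rows are additionally compared by hashing their rounded bytes to catch duplicated vectors. The toy tensor below is illustrative only and is not part of the commit; the real code runs the duplicate check on the lm_head rows at the indices flagged in the input embeddings.

# Illustrative only -- toy data, not from the commit.
from collections import Counter

import torch

eps = 1e-16

# Toy "embedding matrix": rows 1, 3 and 4 are (near) zero and identical after rounding.
embedding = torch.tensor(
    [
        [0.5, -0.2, 0.1],
        [0.0, 0.0, 0.0],
        [0.3, 0.7, -0.4],
        [1e-20, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]
)

# Check 1: rows whose maximum value is at most eps
indicator_untrained = torch.amax(embedding, dim=1) <= eps
print(indicator_untrained.tolist())  # [False, True, False, True, True]

# Check 2: among the flagged rows, find vectors that are exact duplicates after rounding,
# by hashing each row's bytes (mirrors the Counter-based check in the diff)
suspect_idx = torch.where(indicator_untrained)[0]
suspect_rows = embedding[suspect_idx].cpu().float().numpy().round(3)
counter = Counter(hash(row.data.tobytes()) for row in suspect_rows)
duplicated = [
    int(suspect_idx[j]) for j, row in enumerate(suspect_rows)
    if counter[hash(row.data.tobytes())] > 1
]
print(duplicated)  # [1, 3, 4] -- all three share an identical rounded vector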
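A similarly hedged sketch of what happens once the untrained rows are known: the mean of the trained rows is obtained by summing all rows, subtracting the untrained ones, and dividing by the trained count, and the flagged rows are overwritten with it. The per-token frequency scaling and the restriction to tokens actually seen in train_dataset (both visible in the diff above) are omitted here for brevity.

# Illustrative only -- toy data, not from the commit.
import torch

eps = 1e-16
embedding = torch.tensor(
    [
        [0.5, -0.2, 0.1],
        [0.0, 0.0, 0.0],
        [0.3, 0.7, -0.4],
        [0.0, 0.0, 0.0],
    ]
)

where_untrained = torch.where(torch.amax(embedding, dim=1) <= eps)[0]
n_trained = embedding.shape[0] - where_untrained.shape[0]

# Sum every row, remove the untrained rows, then average over the trained rows only
sum_embedding = torch.sum(embedding, dtype=torch.float32, dim=0)
sum_embedding -= torch.sum(embedding[where_untrained], dtype=torch.float32, dim=0)
mean_embedding = sum_embedding / n_trained

# Overwrite the untrained rows with the (repeated) mean vector
tokens_to_update = where_untrained.tolist()
embedding[tokens_to_update] = (
    mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1).to(embedding.dtype)
)
print(embedding[1])  # tensor([ 0.4000,  0.2500, -0.1500]) -- mean of rows 0 and 2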
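Finally, a hypothetical call site (not part of this commit) showing how the new is_ds_zero3 flag is intended to be used: when it is set, the embedding and lm_head weights are gathered on rank 0 via deepspeed.zero.GatheredParameters before being inspected and updated. model, tokenizer, and a tokenized train_dataset with an "input_ids" column are assumed to already exist in the training script.

# Hypothetical usage sketch -- model, tokenizer and train_dataset are assumed to exist.
from axolotl.contribs.lgpl.unsloth import fix_untrained_tokens

fix_untrained_tokens(
    model,
    tokenizer,
    train_dataset,
    ignored_tokenizer_names=[],  # model names (config._name_or_path) to skip entirely
    token_ids_to_fix=[],         # extra token IDs to force into the fix
    is_ds_zero3=True,            # gather sharded weights under DeepSpeed ZeRO-3 before patching
)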
