 import itertools
 import logging
 from collections import Counter
+from contextlib import nullcontext
+from functools import partial

 import datasets
+import deepspeed
 import numpy as np
 import torch

 LOG = logging.getLogger("axolotl.contribs.lgpl.unsloth")


-@torch.inference_mode()
-def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
-    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None,
-):
+def get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=1e-16, token_ids_to_fix=None):
     """
-    Llama-3 for eg has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-    We reset them to the mean of the rest of the tokens
+    Detects tokens with untrained embeddings (values close to zero or duplicated) and
+    prepares scaled mean embedding values for these tokens based on their frequency in the training data.
+
+    Args:
+        input_embeddings: The model's input embedding layer (embed_tokens)
+        output_embeddings: The model's output embedding layer (lm_head)
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset used to analyze token frequencies
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process
+
+    Returns:
+        tuple: If untrained tokens are found in the training data, returns:
+            - mean_embedding_repeated: Scaled mean embeddings for the input layer
+            - mean_lm_head_repeated: Scaled mean embeddings for the output layer
+            - tokens_to_update: List of token IDs that need updating
+        None: If no untrained tokens are found in the training data
+
+    Raises:
+        ValueError: If the embedding matrices have incorrect shapes, or if untrained tokens are found but the embeddings are not trainable
     """
-    # Code licensed under LGPL
-    if not token_ids_to_fix:
-        token_ids_to_fix = []
-    embedding_matrix = model.get_input_embeddings().weight
-    lm_head_matrix = model.get_output_embeddings().weight
+
+    # Check if we still have an issue with the shapes
+    if input_embeddings.weight.shape[0] == 0 or output_embeddings.weight.shape[0] == 0:
+        raise ValueError(
+            f"Could not gather embedding matrices properly. "
+            f"Shapes: embedding={input_embeddings.weight.shape}, lm_head={output_embeddings.weight.shape}. "
+            f"This might indicate a DeepSpeed configuration issue."
+        )
+
+    # Get the chat template if available
     chat_template = getattr(tokenizer, "chat_template", None)
     tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer

-    # Ignore some model checks for now
-    if not ignored_tokenizer_names:
-        ignored_tokenizer_names = []
-    if (
-        model.config._name_or_path  # pylint: disable=protected-access
-        in ignored_tokenizer_names
-    ):
-        return
-
     # Sometimes the sizes can be different like in vision models
     # Ie <image> is in input, but not in output
-    min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
-    embedding_matrix = embedding_matrix[:, :min_size]
-    lm_head_matrix = lm_head_matrix[:, :min_size]
+    min_size = min(input_embeddings.weight.shape[1], output_embeddings.weight.shape[1])
+    input_embeddings.weight = torch.nn.Parameter(input_embeddings.weight[:, :min_size])
+    output_embeddings.weight = torch.nn.Parameter(output_embeddings.weight[:, :min_size])

     # Get untrained tokens
-    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
-    # Check lm_head as well
+    indicator_untrained1 = torch.amax(input_embeddings.weight, axis=1) <= eps

     # Does NOT work for Llama 3.1!!
-    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+    indicator_untrained2 = torch.amax(output_embeddings.weight, axis=1) <= eps

     # We instead check for repeated vectors
     lm_head_where = torch.where(indicator_untrained1)[0]
-    lm_head_bad = lm_head_matrix[lm_head_where]
+    lm_head_bad = output_embeddings.weight[lm_head_where]
     lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
     counter = Counter()
     for row in lm_head_bad:
@@ -75,7 +87,8 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
     for j, row in enumerate(lm_head_bad):
         if hash(row.data.tobytes()) in counter:
             final_bad_lm_head.append(lm_head_where[j])
-    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
+
+    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained1, dtype=torch.bool)
     indicator_untrained2[final_bad_lm_head] = True

     # Combine both checks
@@ -89,7 +102,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

     where_untrained = torch.where(indicator_untrained)[0]
     n_untrained = where_untrained.shape[0]
-    n_trained = embedding_matrix.shape[0] - n_untrained
+    n_trained = input_embeddings.weight.shape[0] - n_untrained

     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
@@ -145,9 +158,9 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

     # Check if lm_head / embed_token are trainable!
     bad_not_trainable = False
-    if not embedding_matrix.requires_grad:
+    if not input_embeddings.weight.requires_grad:
         bad_not_trainable = True
-    if not lm_head_matrix.requires_grad:
+    if not output_embeddings.weight.requires_grad:
         bad_not_trainable = True

     if bad_not_trainable:  # pylint: disable=too-many-nested-blocks
@@ -176,8 +189,6 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

     # If no bad tokens, possibly chat template itself has issues?
     if len(final_bad_items) == 0:
-        # Recheck 2000 and last 2000 items
-        size_dataset = len(train_dataset)
         size = min(size_dataset, 2000)
         for j in range(size):
             input_ids = train_dataset[j]
@@ -207,7 +218,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements

     # Count all the possible bad tokens
     final_counts = np.zeros(
-        max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+        max(len(tokenizer), input_embeddings.weight.shape[0]), dtype=np.int64
     )

     def mapping(examples):
@@ -235,15 +246,15 @@ def mapping(examples):
     )

     # Get sum of all items
-    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
-    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+    sum_embedding = torch.sum(input_embeddings.weight, dtype=torch.float32, axis=0)
+    sum_lm_head = torch.sum(output_embeddings.weight, dtype=torch.float32, axis=0)

     # Remove bad tokens
     sum_embedding -= torch.sum(
-        embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+        input_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )
     sum_lm_head -= torch.sum(
-        lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+        output_embeddings.weight[where_untrained], dtype=torch.float32, axis=0
     )

     # Find correct average by dividing by sum of trained tokens
@@ -262,14 +273,60 @@ def mapping(examples):
         mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
     )

-    # Update embeddings only for tokens seen in train_dataset
-    embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
-        embedding_matrix.dtype
-    )
-    lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
+    return mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update
+
+def fix_untrained_tokens(
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16, token_ids_to_fix=None, is_ds_zero3=False
+):
+    """
+    Some base models ship with untrained embedding vectors for certain tokens. Update these embeddings
+    to the mean of the rest of the tokens. This additionally handles distributed settings such as DeepSpeed ZeRO-3.
+
+    Args:
+        model: The model whose embeddings should be fixed
+        tokenizer: The tokenizer used with the model
+        train_dataset: The training dataset used to analyze token frequencies
+        ignored_tokenizer_names: List of model names to skip processing for (default: None)
+        eps: Small epsilon value to detect near-zero embeddings (default: 1e-16)
+        token_ids_to_fix: Additional token IDs to include in the fixing process (default: None)
+        is_ds_zero3: Whether DeepSpeed ZeRO-3 is being used (default: False)
+    """
+
+    if not token_ids_to_fix:
+        token_ids_to_fix = []
+    if not ignored_tokenizer_names:
+        ignored_tokenizer_names = []
+
+    # Check if we should skip this model
+    if (
+        hasattr(model, "config") and
+        hasattr(model.config, "_name_or_path") and
+        model.config._name_or_path in ignored_tokenizer_names
+    ):
+        return
+
+    with torch.no_grad():
+        # Get the embedding layer and lm_head
+        embedding_layer = model.get_input_embeddings()
+        lm_head_layer = model.get_output_embeddings()
+
+        context = nullcontext
+        if is_ds_zero3:
+            # Gather the full (unpartitioned) parameters when running under DeepSpeed ZeRO-3
+            context = partial(deepspeed.zero.GatheredParameters, [embedding_layer.weight, lm_head_layer.weight], modifier_rank=0)
+
+        with context():
+            input_embeddings = model.get_input_embeddings()
+            output_embeddings = model.get_output_embeddings()
+            result = get_embedding_mean(input_embeddings, output_embeddings, tokenizer, train_dataset, eps=eps, token_ids_to_fix=token_ids_to_fix)
+            if result is None:  # no untrained tokens found in the training data
+                return
+            mean_embedding_repeated, mean_lm_head_repeated, tokens_to_update = result
+            input_embeddings.weight[tokens_to_update] = mean_embedding_repeated.to(
+                input_embeddings.weight.dtype
+            )
+            model.set_input_embeddings(input_embeddings)
+            output_embeddings.weight[tokens_to_update] = mean_lm_head_repeated.to(output_embeddings.weight.dtype)
+            model.set_output_embeddings(output_embeddings)

     # Clean up
     for _ in range(3):
         gc.collect()
         torch.cuda.empty_cache()
-    return
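
For context on the new `deepspeed.zero.GatheredParameters` wrapper and the shape check at the top of `get_embedding_mean`: under ZeRO-3 every parameter is partitioned across ranks, so outside a gather context `embedding_layer.weight` on a given rank can report a leading dimension of 0. The gather context temporarily materializes the full tensors, and `modifier_rank=0` persists rank 0's in-place modifications back to the shards on exit. Below is a minimal sketch of that pattern, not part of this commit; the model handle and helper name are placeholders and assume the model was initialized under DeepSpeed ZeRO-3.

```python
import deepspeed
import torch


def scale_embeddings_under_zero3(model, factor: float = 1.0):
    """Toy illustration of the gather-then-modify pattern used above."""
    embed = model.get_input_embeddings()
    # Outside the gather context each rank holds only its shard of the weight,
    # so embed.weight.shape[0] may be 0 here (what the ValueError check guards against).
    with deepspeed.zero.GatheredParameters([embed.weight], modifier_rank=0):
        # Inside the context every rank sees the full tensor. Because modifier_rank=0,
        # only rank 0's in-place edits are written back to the shards on exit.
        if torch.distributed.get_rank() == 0:
            with torch.no_grad():
                embed.weight.mul_(factor)
```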
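A hedged sketch of a call site for the updated `fix_untrained_tokens`: the model name, the toy dataset, and the use of `is_deepspeed_zero3_enabled()` from transformers' DeepSpeed integration to populate `is_ds_zero3` are assumptions for illustration, not taken from this commit.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled  # assumed detection helper

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # example only
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenized training data the helper analyzes for token frequencies (assumed "input_ids" column).
train_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer("Hello there<|eot_id|>").input_ids]}
)

fix_untrained_tokens(
    model,
    tokenizer,
    train_dataset,
    token_ids_to_fix=[tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    is_ds_zero3=is_deepspeed_zero3_enabled(),
)
```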