Validate tokenizer and model alignment before training #2074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes shown are from 1 commit; the PR contains commits 59dc807, faa2122, 50b10b8, 883c281, cdc1f1b, 31853e8, and 24557cf.
@@ -468,3 +468,45 @@ def get_moe_model_nparams_and_flops( | |
| nparams = nparams - nparams_embedding | ||
|
|
||
| return nparams, num_flops_per_token | ||
|
|
||
|
|
||
| def validate_tokenizer_model_alignment( | ||
| tokenizer: "BaseTokenizer | None", | ||
| model_args: "BaseModelArgs", | ||
|
||
| ) -> None: | ||
| """ | ||
| Validate that tokenizer configuration matches model configuration. | ||
|
|
||
| Args: | ||
| tokenizer: Tokenizer instance to validate. Can be None. | ||
| model_args: Model arguments object containing configuration to validate against. | ||
|
|
||
| Raises: | ||
| ValueError: If tokenizer and model configurations don't match. | ||
| """ | ||
| if tokenizer is None: | ||
| return | ||
|
|
||
| # Validate vocab_size | ||
| if hasattr(model_args, "vocab_size"): | ||
| tokenizer_vocab_size = tokenizer.get_vocab_size() | ||
| model_vocab_size = model_args.vocab_size | ||
| if tokenizer_vocab_size != model_vocab_size: | ||
|
||
| raise ValueError( | ||
| f"Tokenizer vocab_size ({tokenizer_vocab_size}) does not match " | ||
| f"model vocab_size ({model_vocab_size}). " | ||
| f"This mismatch will cause training errors. " | ||
| f"Please ensure the tokenizer and model configuration are aligned." | ||
| ) | ||
|
|
||
| # Validate eos_id | ||
| if hasattr(model_args, "eos_id"): | ||
|
||
| tokenizer_eos_id = getattr(tokenizer, "eos_id", None) | ||
| model_eos_id = model_args.eos_id | ||
| if tokenizer_eos_id is not None and tokenizer_eos_id != model_eos_id: | ||
| raise ValueError( | ||
| f"Tokenizer eos_id ({tokenizer_eos_id}) does not match " | ||
| f"model eos_id ({model_eos_id}). " | ||
| f"This mismatch may cause training errors. " | ||
| f"Please ensure the tokenizer and model configuration are aligned." | ||
| ) | ||
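As a sanity check of the behavior above, here is a minimal sketch (not part of the PR) showing how the helper rejects a mismatched pair. `DummyTokenizer` and the `SimpleNamespace` model args are hypothetical stand-ins, and `validate_tokenizer_model_alignment` is assumed to be in scope:

```python
from types import SimpleNamespace

# Hypothetical stand-ins used only to exercise validate_tokenizer_model_alignment;
# they are not types from this repository.
class DummyTokenizer:
    eos_id = 2

    def get_vocab_size(self) -> int:
        return 32000


# Model args with a padded vocab size that intentionally disagrees with the tokenizer.
model_args = SimpleNamespace(vocab_size=32064, eos_id=2)

try:
    validate_tokenizer_model_alignment(DummyTokenizer(), model_args)
except ValueError as err:
    # Expected: the vocab sizes differ, so the check fails before any training step runs.
    print(f"Alignment check failed: {err}")
```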
tianyu-l (review comment): maybe call it … since we no longer require them to be identical
Author reply: @tianyu-l That makes sense. I've updated the function name accordingly, and also reverted the removal of `eos_id`. Thanks for the suggestion!
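The renamed version is not visible in the single commit shown above, but based on the discussion (the two configurations no longer have to be identical), a relaxed check could look roughly like the sketch below. The name `check_tokenizer_model_compatibility` and the "model vocab may be padded larger than the tokenizer vocab" rule are assumptions for illustration, not the PR's final code:

```python
import logging

logger = logging.getLogger(__name__)


def check_tokenizer_model_compatibility(tokenizer, model_args) -> None:
    # Assumed relaxation: the model's (possibly padded) vocab must be large enough to
    # hold every tokenizer id, but the two sizes no longer need to match exactly.
    if tokenizer is None:
        return
    if hasattr(model_args, "vocab_size"):
        tok_vocab = tokenizer.get_vocab_size()
        if tok_vocab > model_args.vocab_size:
            raise ValueError(
                f"Tokenizer vocab_size ({tok_vocab}) exceeds model vocab_size "
                f"({model_args.vocab_size}); token ids would index out of range."
            )
        if tok_vocab < model_args.vocab_size:
            logger.warning(
                "Model vocab_size (%d) is larger than tokenizer vocab_size (%d); "
                "assuming the extra rows are padding.",
                model_args.vocab_size,
                tok_vocab,
            )
```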