Commit b26bf7f

cuichenx authored and chtruong814 committed
Add cp per token loss check (#14282)
* add cp per token loss check

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Signed-off-by: Charlie Truong <[email protected]>
1 parent b807e7a commit b26bf7f

File tree

1 file changed (+5 -0 lines)


nemo/collections/llm/api.py

Lines changed: 5 additions & 0 deletions
@@ -35,6 +35,7 @@
     EvaluationTarget,
     MisconfigurationError,
 )
+from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
 from nemo.collections.llm.modelopt import (
     DistillationGPTModel,
     ExportConfig,
@@ -1359,6 +1360,10 @@ def _validate_config(
             assert (
                 model.config.seq_length % (trainer.strategy.context_parallel_size * 2) == 0
             ), 'Sequence length must be divisible by 2 * context parallel size if context parallel is used.'
+        if isinstance(data, FineTuningDataModule):
+            assert model.config.calculate_per_token_loss, (
+                "When finetuning with CP>1, " "model.config.calculate_per_token_loss must be True"
+            )
 
     # EP validation
     if trainer.strategy.expert_model_parallel_size > 1:
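
For context, below is a minimal standalone sketch of the logic this diff enforces. It is illustrative rather than NeMo's actual implementation: ModelConfig and Strategy are hypothetical stand-ins for model.config and trainer.strategy, and the rationale in the comments (per-token loss keeps masked fine-tuning tokens weighted correctly across context-parallel ranks) is the usual motivation for such a check, not something the commit itself states.

# Illustrative sketch of the CP validation added in this commit.
# ModelConfig and Strategy are hypothetical stand-ins for model.config
# and trainer.strategy; this is not NeMo's actual code.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    seq_length: int = 4096
    calculate_per_token_loss: bool = False


@dataclass
class Strategy:
    context_parallel_size: int = 2


def validate_cp(config: ModelConfig, strategy: Strategy, is_finetuning: bool) -> None:
    """Mirror the CP checks from _validate_config in nemo/collections/llm/api.py."""
    if strategy.context_parallel_size > 1:
        # Pre-existing check: CP splits each sequence into 2 * CP chunks for
        # load balancing, so the sequence length must divide evenly.
        assert config.seq_length % (strategy.context_parallel_size * 2) == 0, (
            "Sequence length must be divisible by 2 * context parallel size "
            "if context parallel is used."
        )
        # New check: when fine-tuning (masked loss) with CP>1, require
        # per-token loss, likely because CP ranks hold different numbers of
        # loss-contributing tokens, so a per-microbatch average would weight
        # ranks unevenly.
        if is_finetuning:
            assert config.calculate_per_token_loss, (
                "When finetuning with CP>1, model.config.calculate_per_token_loss must be True"
            )


# Passes: 4096 % (2 * 2) == 0 and per-token loss is enabled.
validate_cp(ModelConfig(calculate_per_token_loss=True), Strategy(), is_finetuning=True)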
