diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py
index 2af9538494b5..9a4871bf5437 100644
--- a/nemo/collections/llm/gpt/data/pre_training.py
+++ b/nemo/collections/llm/gpt/data/pre_training.py
@@ -16,7 +16,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union
 
 import lightning.pytorch as pl
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
@@ -153,6 +153,10 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
         num_test_samples (Optional[int]): The number of samples to use for testing,
             defaults to total test steps times global batch size.
         dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module.
+        dataloader_type (Optional[Literal["single", "cyclic", "batch"]]): Data loading strategy.
+        init_consumed_samples (Optional[int]): Number of samples already consumed at initialization.
+        init_global_step (Optional[int]): Starting global training step count, used for resuming training.
+        output_log (Optional[bool]): Whether to print logging/debug output during sampling.
     """
 
     def __init__(
@@ -177,6 +181,10 @@ def __init__(
         num_train_samples: Optional[int] = None,
         num_val_samples: Optional[int] = None,
         num_test_samples: Optional[int] = None,
+        dataloader_type: Optional[Literal["single", "cyclic", "batch"]] = "single",
+        init_consumed_samples: Optional[int] = 0,
+        init_global_step: Optional[int] = 0,
+        output_log: Optional[bool] = True,
         dataset_cls: Type[MegatronDataset] = GPTDataset,
     ) -> None:
         super().__init__()
@@ -227,6 +235,10 @@ def __init__(
         self.num_train_samples = num_train_samples
         self.num_val_samples = num_val_samples
         self.num_test_samples = num_test_samples
+        self.dataloader_type = dataloader_type
+        self.init_consumed_samples = init_consumed_samples
+        self.init_global_step = init_global_step
+        self.output_log = output_log
 
         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
@@ -236,6 +248,10 @@ def __init__(
             micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=rampup_batch_size,
+            dataloader_type=self.dataloader_type,
+            init_consumed_samples=self.init_consumed_samples,
+            init_global_step=self.init_global_step,
+            output_log=self.output_log,
         )
 
     def build(
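Usage sketch (not part of the diff): a minimal example of constructing PreTrainingDataModule with the new keyword arguments, which the last hunk forwards to the data sampler. The dataset path is hypothetical, and paths/seq_length are assumed from the module's existing signature rather than from this diff.

from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule

data = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # hypothetical preprocessed dataset prefix
    seq_length=2048,                          # assumed existing argument, not shown in this diff
    micro_batch_size=1,
    global_batch_size=8,
    dataloader_type="single",       # new: one of "single", "cyclic", "batch"
    init_consumed_samples=0,        # new: samples already consumed when resuming
    init_global_step=0,             # new: global step to resume from
    output_log=True,                # new: emit sampler logging/debug output
)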