1616import os
1717import warnings
1818from pathlib import Path
19- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Type , Union
19+ from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Type , Union
2020
2121import lightning .pytorch as pl
2222from lightning .pytorch .utilities .rank_zero import rank_zero_info
@@ -153,6 +153,10 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
153153 num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total
154154 test steps times global batch size.
155155 dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module.
156+ dataloader_type (Optional[Literal["single", "cyclic", "batch"]]): Data loading strategy.
157+ init_consumed_samples: (Optional[int]): Number of samples already consumed at initialization.
158+ init_global_step: (Optional[int]): Starting global training step count, used for resuming training.
159+ output_log: (Optional[bool]): Whether to print logging/debug output during sampling.
156160 """
157161
158162 def __init__ (
@@ -177,6 +181,10 @@ def __init__(
177181 num_train_samples : Optional [int ] = None ,
178182 num_val_samples : Optional [int ] = None ,
179183 num_test_samples : Optional [int ] = None ,
184+ dataloader_type : Optional [Literal ["single" , "cyclic" , "batch" ]] = "single" ,
185+ init_consumed_samples : Optional [int ] = 0 ,
186+ init_global_step : Optional [int ] = 0 ,
187+ output_log : Optional [bool ] = True ,
180188 dataset_cls : Type [MegatronDataset ] = GPTDataset ,
181189 ) -> None :
182190 super ().__init__ ()
@@ -227,6 +235,10 @@ def __init__(
227235 self .num_train_samples = num_train_samples
228236 self .num_val_samples = num_val_samples
229237 self .num_test_samples = num_test_samples
238+ self .dataloader_type = dataloader_type
239+ self .init_consumed_samples = init_consumed_samples
240+ self .init_global_step = init_global_step
241+ self .output_log = output_log
230242
231243 from nemo .collections .nlp .modules .common .tokenizer_utils import get_nmt_tokenizer
232244
@@ -236,6 +248,10 @@ def __init__(
236248 micro_batch_size = self .micro_batch_size ,
237249 global_batch_size = self .global_batch_size ,
238250 rampup_batch_size = rampup_batch_size ,
251+ dataloader_type = self .dataloader_type ,
252+ init_consumed_samples = self .init_consumed_samples ,
253+ init_global_step = self .init_global_step ,
254+ output_log = self .output_log ,
239255 )
240256
241257 def build (
0 commit comments