
Commit b7f72ea

dimapihtar authored and nasretdinovr committed
add extra params for MegatronDataSampler (NVIDIA-NeMo#13956)
* add extra params for MegatronDataSampler

  Signed-off-by: dimapihtar <[email protected]>

* fix style

  Signed-off-by: dimapihtar <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
1 parent 4835627 · commit b7f72ea

File tree

1 file changed: +17, -1 lines changed


nemo/collections/llm/gpt/data/pre_training.py

Lines changed: 17 additions & 1 deletion
@@ -16,7 +16,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union

 import lightning.pytorch as pl
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
@@ -153,6 +153,10 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
         num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total
             test steps times global batch size.
         dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module.
+        dataloader_type (Optional[Literal["single", "cyclic", "batch"]]): Data loading strategy.
+        init_consumed_samples (Optional[int]): Number of samples already consumed at initialization.
+        init_global_step (Optional[int]): Starting global training step count, used for resuming training.
+        output_log (Optional[bool]): Whether to print logging/debug output during sampling.
     """

     def __init__(
@@ -177,6 +181,10 @@ def __init__(
         num_train_samples: Optional[int] = None,
         num_val_samples: Optional[int] = None,
         num_test_samples: Optional[int] = None,
+        dataloader_type: Optional[Literal["single", "cyclic", "batch"]] = "single",
+        init_consumed_samples: Optional[int] = 0,
+        init_global_step: Optional[int] = 0,
+        output_log: Optional[bool] = True,
         dataset_cls: Type[MegatronDataset] = GPTDataset,
     ) -> None:
         super().__init__()
@@ -227,6 +235,10 @@ def __init__(
         self.num_train_samples = num_train_samples
         self.num_val_samples = num_val_samples
         self.num_test_samples = num_test_samples
+        self.dataloader_type = dataloader_type
+        self.init_consumed_samples = init_consumed_samples
+        self.init_global_step = init_global_step
+        self.output_log = output_log

         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

@@ -236,6 +248,10 @@ def __init__(
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             rampup_batch_size=rampup_batch_size,
+            dataloader_type=self.dataloader_type,
+            init_consumed_samples=self.init_consumed_samples,
+            init_global_step=self.init_global_step,
+            output_log=self.output_log,
         )

     def build(
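
For context, a minimal sketch of how the new arguments might be passed when constructing the data module. The dataset path, sequence length, and batch sizes below are hypothetical placeholders; only dataloader_type, init_consumed_samples, init_global_step, and output_log come from this commit, and they are forwarded to MegatronDataSampler as shown in the diff above.

from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule

# Hypothetical configuration for illustration; paths, seq_length, and
# batch sizes are placeholders, not values taken from this commit.
data = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # placeholder dataset prefix
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=8,
    dataloader_type="single",  # one of "single", "cyclic", "batch"
    init_consumed_samples=0,   # samples already consumed when resuming
    init_global_step=0,        # global step to resume counting from
    output_log=True,           # emit sampler logging/debug output
)

Leaving the four new arguments at the defaults shown in the diff ("single", 0, 0, True) should be equivalent to the module's behavior before this change, since they are simply passed through to the sampler.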
