nemo/collections/llm/gpt/data/pre_training.py (18 changes: 17 additions & 1 deletion)
@@ -16,7 +16,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union
 
 import lightning.pytorch as pl
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
@@ -153,6 +153,10 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
         num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total
             test steps times global batch size.
         dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module.
+        dataloader_type (Optional[Literal["single", "cyclic", "batch"]]): The data loading strategy.
+        init_consumed_samples (Optional[int]): The number of samples already consumed at initialization.
+        init_global_step (Optional[int]): The starting global training step, used when resuming training.
+        output_log (Optional[bool]): Whether to print logging/debug output during sampling.
     """
 
     def __init__(
@@ -177,6 +181,10 @@ def __init__(
         num_train_samples: Optional[int] = None,
         num_val_samples: Optional[int] = None,
         num_test_samples: Optional[int] = None,
+        dataloader_type: Optional[Literal["single", "cyclic", "batch"]] = "single",
+        init_consumed_samples: Optional[int] = 0,
+        init_global_step: Optional[int] = 0,
+        output_log: Optional[bool] = True,
         dataset_cls: Type[MegatronDataset] = GPTDataset,
     ) -> None:
         super().__init__()
@@ -227,6 +235,10 @@ def __init__(
         self.num_train_samples = num_train_samples
         self.num_val_samples = num_val_samples
         self.num_test_samples = num_test_samples
+        self.dataloader_type = dataloader_type
+        self.init_consumed_samples = init_consumed_samples
+        self.init_global_step = init_global_step
+        self.output_log = output_log
 
         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
@@ -236,6 +248,10 @@ def __init__(
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             rampup_batch_size=rampup_batch_size,
+            dataloader_type=self.dataloader_type,
+            init_consumed_samples=self.init_consumed_samples,
+            init_global_step=self.init_global_step,
+            output_log=self.output_log,
         )
 
     def build(
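As a usage note for reviewers: the snippet below shows how the new arguments would be passed when constructing the data module. It is a minimal sketch, assuming a working NeMo environment; the dataset path prefix and the specific argument values are hypothetical placeholders, not part of this PR, and the comments restate the docstring semantics from the diff above.

# Minimal usage sketch (illustrative, not part of this PR).
from nemo.collections.llm.gpt.data import PreTrainingDataModule

data = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # hypothetical preprocessed data prefix
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=8,
    # New in this PR; forwarded to the data sampler built in __init__:
    dataloader_type="cyclic",   # one of "single", "cyclic", "batch"
    init_consumed_samples=0,    # samples already consumed, e.g. when resuming
    init_global_step=0,         # starting global step when resuming training
    output_log=False,           # suppress sampler logging/debug output
)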