Merged
nemo/collections/llm/gpt/data/pre_training.py (16 changes: 14 additions & 2 deletions)
```diff
@@ -29,6 +29,7 @@
 from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils.import_utils import safe_import
+from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url
 
 _, HAVE_TE = safe_import("transformer_engine")
```
```diff
@@ -85,7 +86,12 @@ def validate_dataset_asset_accessibility(paths):
     if not isinstance(paths, str) and not isinstance(paths, Path):
         raise ValueError("Expected path to be of string or Path type.")
 
-    path = Path(paths)
+    if is_multistorageclient_url(paths):
+        msc = import_multistorageclient()
+        path = msc.Path(paths)
+    else:
+        path = Path(paths)
+
     suffices = (".bin", ".idx")
     if path.is_dir():
         if not os.access(path, os.R_OK):
```
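The new branch dispatches on the URL scheme: `msc://` paths get the multi-storage client's `Path` implementation, while everything else keeps going through `pathlib.Path`. A minimal sketch of the dispatch with stand-in helpers (the real ones live in `nemo.utils.msc_utils` and may differ in detail):

```python
from pathlib import Path


def is_msc_url(path) -> bool:
    # Stand-in for is_multistorageclient_url: treat the msc:// scheme as object storage.
    return isinstance(path, str) and path.startswith("msc://")


def resolve_path(paths):
    if is_msc_url(paths):
        # Stand-in for import_multistorageclient: the import is deferred because
        # the multi-storage client is an optional dependency.
        import multistorageclient as msc

        return msc.Path(paths)  # pathlib-style object backed by object storage
    return Path(paths)
```

The change relies on `msc.Path` mirroring the `pathlib` interface (`is_dir`, `exists`, `with_suffix`), so the accessibility checks below run unchanged on either result.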
```diff
@@ -97,7 +103,7 @@ def validate_dataset_asset_accessibility(paths):
             raise PermissionError(f"Expected {str(path)} to be readable.")
         return
     for suffix in suffices:
-        file_path = Path(str(path) + suffix)
+        file_path = path.with_suffix(suffix)
         if not file_path.exists():
             raise FileNotFoundError(f"Expected {str(file_path)} to exist.")
         if not os.access(file_path, os.R_OK):
```
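Switching from string concatenation to `with_suffix` keeps the `.bin`/`.idx` lookup uniform across `pathlib.Path` and `msc.Path`. It relies on dataset prefixes having no extension, in which case `with_suffix` appends; a quick illustration with hypothetical paths:

```python
from pathlib import Path

# A typical dataset prefix has no extension, so with_suffix() simply appends.
prefix = Path("/data/gpt/train_text_document")
assert prefix.with_suffix(".bin") == Path("/data/gpt/train_text_document.bin")
assert prefix.with_suffix(".idx") == Path("/data/gpt/train_text_document.idx")

# Caveat: if the last path component contains a dot, with_suffix() replaces the
# existing suffix rather than appending, unlike Path(str(path) + suffix).
dotted = Path("/data/gpt/v1.2_text_document")
assert str(dotted.with_suffix(".bin")) == "/data/gpt/v1.bin"
```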
```diff
@@ -157,6 +163,7 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
         init_consumed_samples: (Optional[int]): Number of samples already consumed at initialization.
         init_global_step: (Optional[int]): Starting global training step count, used for resuming training.
         output_log: (Optional[bool]): Whether to print logging/debug output during sampling.
+        object_storage_cache_path: (Optional[str]): Path for caching indices for s3 or msc dataloading.
     """
 
     def __init__(
```
```diff
@@ -186,6 +193,7 @@ def __init__(
         init_global_step: Optional[int] = 0,
         output_log: Optional[bool] = True,
         dataset_cls: Type[MegatronDataset] = GPTDataset,
+        object_storage_cache_path: Optional[str] = None,
     ) -> None:
         super().__init__()
         if not isinstance(paths, (list, tuple, dict)):
```
```diff
@@ -215,6 +223,10 @@ def __init__(
             build_kwargs["blend"] = [paths, weights]
             build_kwargs["split"] = split
 
+        if object_storage_cache_path:
+            build_kwargs["object_storage_cache_path"] = object_storage_cache_path
+            build_kwargs["mmap_bin_files"] = False
+
         self.build_kwargs = build_kwargs
         self.seq_length = seq_length
         self.micro_batch_size = micro_batch_size
```
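With the new keyword in place, a caller can point the data module at object storage while keeping index caches on local disk. A hypothetical invocation (URL, cache path, and batch sizes are made up; `mmap_bin_files` is forced off, presumably because memory-mapping requires a local file, which object storage cannot provide):

```python
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule

data = PreTrainingDataModule(
    # Hypothetical MSC URL: profile "default", dataset prefix within that profile.
    paths=["msc://default/datasets/gpt/train_text_document"],
    seq_length=2048,
    micro_batch_size=2,
    global_batch_size=16,
    # Dataset indices are cached here instead of being memory-mapped in place.
    object_storage_cache_path="/tmp/msc_index_cache",
)
```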
tests/collections/llm/gpt/data/test_pre_training_data.py (14 changes: 14 additions & 0 deletions)
```diff
@@ -112,3 +112,17 @@ def test_validate_dataset_asset_accessibility_file_is_none(tokenizer, trainer):
         raised_exception = True
 
     assert raised_exception == True, "Expected to raise a ValueError"
+
+
+def test_object_storage_cache_path(tokenizer):
+    data = PreTrainingDataModule(
+        paths=[f"msc://default{DATA_PATH}"],
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+        tokenizer=tokenizer,
+        object_storage_cache_path="/tmp/object_storage_cache_path",
+    )
+
+    assert data.build_kwargs["object_storage_cache_path"] == "/tmp/object_storage_cache_path"
+    assert data.build_kwargs["mmap_bin_files"] == False
```
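The test composes its MSC URL by prefixing the local fixture path with `msc://default`, i.e. the profile named `default` in the multi-storage client configuration followed by the path within that profile's backend, and then asserts that the cache path is forwarded into `build_kwargs` with `mmap_bin_files` disabled. Illustratively, with an assumed fixture value:

```python
DATA_PATH = "/tmp/data/gpt/train_text_document"  # assumed; the real fixture is defined elsewhere in the test module
url = f"msc://default{DATA_PATH}"
# -> "msc://default/tmp/data/gpt/train_text_document"
```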