From 51c8d2818516ec94bf7497d0204b70f80e8da64d Mon Sep 17 00:00:00 2001
From: Shunjia Ding
Date: Wed, 2 Jul 2025 14:09:00 +0800
Subject: [PATCH] Add object_storage_cache_path to PreTrainingDataModule

Signed-off-by: Shunjia Ding
---
 nemo/collections/llm/gpt/data/pre_training.py | 16 ++++++++++++++--
 .../llm/gpt/data/test_pre_training_data.py    | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py
index 9a4871bf5437..9ec94d4f405f 100644
--- a/nemo/collections/llm/gpt/data/pre_training.py
+++ b/nemo/collections/llm/gpt/data/pre_training.py
@@ -29,6 +29,7 @@
 from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils.import_utils import safe_import
+from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url
 
 _, HAVE_TE = safe_import("transformer_engine")
 
@@ -85,7 +86,12 @@ def validate_dataset_asset_accessibility(paths):
     if not isinstance(paths, str) and not isinstance(paths, Path):
         raise ValueError("Expected path to be of string or Path type.")
 
-    path = Path(paths)
+    if is_multistorageclient_url(paths):
+        msc = import_multistorageclient()
+        path = msc.Path(paths)
+    else:
+        path = Path(paths)
+
     suffices = (".bin", ".idx")
     if path.is_dir():
         if not os.access(path, os.R_OK):
@@ -97,7 +103,7 @@ def validate_dataset_asset_accessibility(paths):
             raise PermissionError(f"Expected {str(path)} to be readable.")
         return
     for suffix in suffices:
-        file_path = Path(str(path) + suffix)
+        file_path = path.with_suffix(suffix)
         if not file_path.exists():
             raise FileNotFoundError(f"Expected {str(file_path)} to exist.")
         if not os.access(file_path, os.R_OK):
@@ -157,6 +163,7 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
         init_consumed_samples: (Optional[int]): Number of samples already consumed at initialization.
         init_global_step: (Optional[int]): Starting global training step count, used for resuming training.
         output_log: (Optional[bool]): Whether to print logging/debug output during sampling.
+        object_storage_cache_path: (Optional[str]): Path for caching indices for s3 or msc dataloading.
     """
 
     def __init__(
@@ -186,6 +193,7 @@ def __init__(
         init_global_step: Optional[int] = 0,
         output_log: Optional[bool] = True,
         dataset_cls: Type[MegatronDataset] = GPTDataset,
+        object_storage_cache_path: Optional[str] = None,
     ) -> None:
         super().__init__()
         if not isinstance(paths, (list, tuple, dict)):
@@ -215,6 +223,10 @@
         build_kwargs["blend"] = [paths, weights]
         build_kwargs["split"] = split
 
+        if object_storage_cache_path:
+            build_kwargs["object_storage_cache_path"] = object_storage_cache_path
+            build_kwargs["mmap_bin_files"] = False
+
         self.build_kwargs = build_kwargs
         self.seq_length = seq_length
         self.micro_batch_size = micro_batch_size
diff --git a/tests/collections/llm/gpt/data/test_pre_training_data.py b/tests/collections/llm/gpt/data/test_pre_training_data.py
index f6167593c953..c2cd04348e38 100644
--- a/tests/collections/llm/gpt/data/test_pre_training_data.py
+++ b/tests/collections/llm/gpt/data/test_pre_training_data.py
@@ -112,3 +112,17 @@ def test_validate_dataset_asset_accessibility_file_is_none(tokenizer, trainer):
         raised_exception = True
 
     assert raised_exception == True, "Expected to raise a ValueError"
+
+
+def test_object_storage_cache_path(tokenizer):
+    data = PreTrainingDataModule(
+        paths=[f"msc://default{DATA_PATH}"],
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+        tokenizer=tokenizer,
+        object_storage_cache_path="/tmp/object_storage_cache_path",
+    )
+
+    assert data.build_kwargs["object_storage_cache_path"] == "/tmp/object_storage_cache_path"
+    assert data.build_kwargs["mmap_bin_files"] == False