Skip to content

Commit a04fe97

Browse files
shunjia and dnasretdinovr
authored and committed
Add object_storage_cache_path to PreTrainingDataModule (NVIDIA-NeMo#14103)
Signed-off-by: Shunjia Ding <[email protected]>
1 parent 0dd9cf7 commit a04fe97

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

nemo/collections/llm/gpt/data/pre_training.py

Lines changed: 14 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -29,6 +29,7 @@
2929
from nemo.lightning.io.mixin import IOMixin
3030
from nemo.lightning.pytorch.plugins import MegatronDataSampler
3131
from nemo.utils.import_utils import safe_import
32+
from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url
3233

3334
_, HAVE_TE = safe_import("transformer_engine")
3435

@@ -85,7 +86,12 @@ def validate_dataset_asset_accessibility(paths):
8586
if not isinstance(paths, str) and not isinstance(paths, Path):
8687
raise ValueError("Expected path to be of string or Path type.")
8788

88-
path = Path(paths)
89+
if is_multistorageclient_url(paths):
90+
msc = import_multistorageclient()
91+
path = msc.Path(paths)
92+
else:
93+
path = Path(paths)
94+
8995
suffices = (".bin", ".idx")
9096
if path.is_dir():
9197
if not os.access(path, os.R_OK):
@@ -97,7 +103,7 @@ def validate_dataset_asset_accessibility(paths):
97103
raise PermissionError(f"Expected {str(path)} to be readable.")
98104
return
99105
for suffix in suffices:
100-
file_path = Path(str(path) + suffix)
106+
file_path = path.with_suffix(suffix)
101107
if not file_path.exists():
102108
raise FileNotFoundError(f"Expected {str(file_path)} to exist.")
103109
if not os.access(file_path, os.R_OK):
@@ -157,6 +163,7 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
157163
init_consumed_samples: (Optional[int]): Number of samples already consumed at initialization.
158164
init_global_step: (Optional[int]): Starting global training step count, used for resuming training.
159165
output_log: (Optional[bool]): Whether to print logging/debug output during sampling.
166+
object_storage_cache_path: (Optional[str]): Path for caching indices for s3 or msc dataloading.
160167
"""
161168

162169
def __init__(
@@ -186,6 +193,7 @@ def __init__(
186193
init_global_step: Optional[int] = 0,
187194
output_log: Optional[bool] = True,
188195
dataset_cls: Type[MegatronDataset] = GPTDataset,
196+
object_storage_cache_path: Optional[str] = None,
189197
) -> None:
190198
super().__init__()
191199
if not isinstance(paths, (list, tuple, dict)):
@@ -215,6 +223,10 @@ def __init__(
215223
build_kwargs["blend"] = [paths, weights]
216224
build_kwargs["split"] = split
217225

226+
if object_storage_cache_path:
227+
build_kwargs["object_storage_cache_path"] = object_storage_cache_path
228+
build_kwargs["mmap_bin_files"] = False
229+
218230
self.build_kwargs = build_kwargs
219231
self.seq_length = seq_length
220232
self.micro_batch_size = micro_batch_size

tests/collections/llm/gpt/data/test_pre_training_data.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -112,3 +112,17 @@ def test_validate_dataset_asset_accessibility_file_is_none(tokenizer, trainer):
112112
raised_exception = True
113113

114114
assert raised_exception == True, "Expected to raise a ValueError"
115+
116+
117+
def test_object_storage_cache_path(tokenizer):
118+
data = PreTrainingDataModule(
119+
paths=[f"msc://default{DATA_PATH}"],
120+
seq_length=512,
121+
micro_batch_size=2,
122+
global_batch_size=2,
123+
tokenizer=tokenizer,
124+
object_storage_cache_path="/tmp/object_storage_cache_path",
125+
)
126+
127+
assert data.build_kwargs["object_storage_cache_path"] == "/tmp/object_storage_cache_path"
128+
assert data.build_kwargs["mmap_bin_files"] == False

0 commit comments

Comments (0)