2929 from nemo.lightning.io.mixin import IOMixin
3030 from nemo.lightning.pytorch.plugins import MegatronDataSampler
3131 from nemo.utils.import_utils import safe_import
32+ from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url
3233
3334 _, HAVE_TE = safe_import("transformer_engine")
3435
@@ -85,7 +86,12 @@ def validate_dataset_asset_accessibility(paths):
8586 if not isinstance(paths, str) and not isinstance(paths, Path):
8687 raise ValueError("Expected path to be of string or Path type.")
8788
88- path = Path(paths)
89+ if is_multistorageclient_url(paths):
90+ msc = import_multistorageclient()
91+ path = msc.Path(paths)
92+ else:
93+ path = Path(paths)
94+
8995 suffices = (".bin", ".idx")
9096 if path.is_dir():
9197 if not os.access(path, os.R_OK):
@@ -97,7 +103,7 @@ def validate_dataset_asset_accessibility(paths):
97103 raise PermissionError(f"Expected {str(path)} to be readable.")
98104 return
99105 for suffix in suffices:
100- file_path = Path(str(path) + suffix)
106+ file_path = path.with_suffix(suffix)
101107 if not file_path.exists():
102108 raise FileNotFoundError(f"Expected {str(file_path)} to exist.")
103109 if not os.access(file_path, os.R_OK):
@@ -157,6 +163,7 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
157163 init_consumed_samples: (Optional[int]): Number of samples already consumed at initialization.
158164 init_global_step: (Optional[int]): Starting global training step count, used for resuming training.
159165 output_log: (Optional[bool]): Whether to print logging/debug output during sampling.
166+ object_storage_cache_path: (Optional[str]): Path for caching indices for s3 or msc dataloading.
160167 """
161168
162169 def __init__(
@@ -186,6 +193,7 @@ def __init__(
186193 init_global_step: Optional[int] = 0,
187194 output_log: Optional[bool] = True,
188195 dataset_cls: Type[MegatronDataset] = GPTDataset,
196+ object_storage_cache_path: Optional[str] = None,
189197 ) -> None:
190198 super().__init__()
191199 if not isinstance(paths, (list, tuple, dict)):
@@ -215,6 +223,10 @@ def __init__(
215223 build_kwargs["blend"] = [paths, weights]
216224 build_kwargs["split"] = split
217225
226+ if object_storage_cache_path:
227+ build_kwargs["object_storage_cache_path"] = object_storage_cache_path
228+ build_kwargs["mmap_bin_files"] = False
229+
218230 self.build_kwargs = build_kwargs
219231 self.seq_length = seq_length
220232 self.micro_batch_size = micro_batch_size
0 commit comments