 import dataclasses
 import enum
 import logging
+import multiprocessing
+import platform
+import sys
 import warnings
 from functools import partial
 from typing import Any
@@ -148,14 +151,78 @@ def get_config_for_mode(self, mode: DatasetMode) -> "MultiDatasetConfig":
         return self.get_subset(datasets_stage_mask)
 
 
+
 class DataModuleConfig(BaseModel):
     datasets: list[SerializeAsAny[BaseModel]]
     batch_size: int = 1
     num_workers: int = 0
     num_workers_validation: int = 0
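+    # "openfold-default" is a sentinel resolved by safe_multiprocessing_context()
+    # below into a platform-appropriate start method; any other value is passed
+    # through to the DataLoader unchanged (with a warning if it looks risky).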
+    multiprocessing_context: str = "openfold-default"
     data_seed: int = 42
     epoch_len: int = 1
 
+    @staticmethod
+    def safe_multiprocessing_context(
+        multiprocessing_context: str | None, num_workers: int
+    ) -> str | None:
+        """
+        Returns a multiprocessing start method with safer/sensible defaults:
+        - fork when using MPS
+        - forkserver on Linux, matching the new Python 3.14 default
+        - the platform default otherwise
+
+        For general info on risks and defaults across platforms and Python versions, see:
+        https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+        https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
+        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
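+
+        Illustrative example (the resolved value depends on the host platform and
+        Python version; the output shown here assumes a Linux machine):
+
+            >>> DataModuleConfig.safe_multiprocessing_context("openfold-default", 4)
+            'forkserver'
+            >>> DataModuleConfig.safe_multiprocessing_context("openfold-default", 0) is None
+            True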
178+ """
179+
180+ # Do not bother if not using multiprocessing
181+ if num_workers == 0 :
182+ return None
183+
184+ # Set safe defaults
185+ if multiprocessing_context == "openfold-default" :
186+
+            # Use fork to create processes when using MPS. See:
+            # - https://github.com/pytorch/pytorch/issues/70344
+            # - https://github.com/pytorch/pytorch/issues/87688
+            if platform.system() == "Darwin" and torch.backends.mps.is_available():
+                return "fork"
+
+            # Use forkserver on Linux.
+            # This backports the new Python 3.14 default to earlier Python versions.
+            # An alternative for further safety would be "spawn". Avoid "fork".
+            # See: https://github.com/python/cpython/issues/84559
+            if platform.system() == "Linux":
+                return "forkserver"
+
+            # Use the platform default otherwise ("spawn" at the time of writing)
+            return multiprocessing.get_start_method()
+
+        # Warn when an explicitly requested start method looks risky
+        else:
+            if platform.system() == "Darwin" and torch.backends.mps.is_available():
+                if multiprocessing_context != "fork":
+                    logger.warning(
+                        f"Using multiprocessing context {multiprocessing_context} on MPS may cause "
+                        "issues. Consider using 'fork' or 'openfold-default' (which resolves to 'fork' on MPS).",
+                        stacklevel=2,
+                    )
+            if platform.system() == "Linux":
+                dangerous_start_method = (
+                    multiprocessing_context == "fork"
+                    or (multiprocessing_context is None and sys.version_info < (3, 14))
+                )
+                if dangerous_start_method:
+                    logger.warning(
+                        "Using the 'fork' multiprocessing context on Linux may cause issues. Consider using "
+                        "'spawn', 'forkserver' or 'openfold-default' (which resolves to 'forkserver' on Linux).",
+                        stacklevel=2,
+                    )
+
+        return multiprocessing_context
+
 
 class DataModule(pl.LightningDataModule):
     """A LightningDataModule class for organizing Datasets and DataLoaders."""
@@ -167,6 +234,7 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:
         self.batch_size = data_module_config.batch_size
         self.num_workers = data_module_config.num_workers
         self.num_workers_validation = data_module_config.num_workers_validation
+        self.multiprocessing_context = data_module_config.multiprocessing_context
         self.data_seed = data_module_config.data_seed
         self.next_data_seed = data_module_config.data_seed
         self.epoch_len = data_module_config.epoch_len
@@ -433,8 +501,17 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
         # instead of pl.seed_everything(workers=True), so this function is
         # passed explicitly here.
         worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
+
+        # Resolve a sensible default for the multiprocessing start method
+        # depending on the platform and Python version.
+        multiprocessing_context = DataModuleConfig.safe_multiprocessing_context(
+            self.multiprocessing_context, num_workers
+        )
+
         logger.debug(
-            f"Creating {mode} dataloader: num_workers={num_workers}, "
+            f"Creating {mode} dataloader: "
+            f"num_workers={num_workers}, "
+            f"multiprocessing_context={multiprocessing_context}, "
             f"rank={self.global_rank}."
         )
         return DataLoader(
@@ -445,10 +522,7 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
             collate_fn=openfold_batch_collator,
             generator=self.generators[mode],
             worker_init_fn=worker_init_fn,
-            # https://github.com/pytorch/pytorch/issues/87688
-            multiprocessing_context="fork"
-            if torch.backends.mps.is_available() and num_workers
-            else None,
+            multiprocessing_context=multiprocessing_context,
         )
 
     def train_dataloader(self) -> DataLoader: