Skip to content

Commit 23811a4

Browse files
committed
Allow configuring the multiprocessing start method and set safe defaults
We would still need to document this for users
1 parent 33b6bca commit 23811a4

1 file changed

Lines changed: 79 additions & 5 deletions

File tree

openfold3/core/data/framework/data_module.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
import dataclasses
4343
import enum
4444
import logging
45+
import multiprocessing
46+
import platform
47+
import sys
4548
import warnings
4649
from functools import partial
4750
from typing import Any
@@ -148,14 +151,78 @@ def get_config_for_mode(self, mode: DatasetMode) -> "MultiDatasetConfig":
148151
return self.get_subset(datasets_stage_mask)
149152

150153

154+
151155
class DataModuleConfig(BaseModel):
    """Configuration for the DataModule.

    Holds dataset definitions plus the DataLoader-related knobs
    (batch size, worker counts, multiprocessing start method, seeding).
    """

    # Dataset configs; SerializeAsAny keeps subclass fields on serialization.
    datasets: list[SerializeAsAny[BaseModel]]
    # Samples per batch.
    batch_size: int = 1
    # Worker process counts for the training / validation dataloaders.
    num_workers: int = 0
    num_workers_validation: int = 0
    # Multiprocessing start method for dataloader workers.
    # "openfold-default" resolves to a safe per-platform choice via
    # safe_multiprocessing_context().
    multiprocessing_context: str = "openfold-default"
    # Seed for data pipeline RNGs.
    data_seed: int = 42
    # Number of samples that constitute one epoch.
    epoch_len: int = 1

    @staticmethod
    def safe_multiprocessing_context(
        multiprocessing_context: str | None, num_workers: int
    ) -> str | None:
        """
        Return a multiprocessing start method with safer/sensible defaults:
        - "fork" when using MPS
        - "forkserver" on Linux, matching the new Python 3.14 default
        - the platform default otherwise

        For general info on risks and defaults across platforms and Python
        versions see:
        https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

        Args:
            multiprocessing_context: Requested start method name, None for the
                DataLoader default, or "openfold-default" to auto-select.
            num_workers: Number of dataloader workers; 0 disables
                multiprocessing entirely.

        Returns:
            The start method name to pass to DataLoader, or None to use the
            platform default (always None when num_workers == 0).
        """

        # Do not bother if not using multiprocessing
        if num_workers == 0:
            return None

        # Set safe defaults
        if multiprocessing_context == "openfold-default":
            # Use fork to create processes when using MPS. See:
            # - https://github.com/pytorch/pytorch/issues/70344
            # - https://github.com/pytorch/pytorch/issues/87688
            if platform.system() == "Darwin" and torch.backends.mps.is_available():
                return "fork"

            # Use forkserver on Linux.
            # Backports the new Python 3.14 default to previous python versions.
            # An alternative for further safety would be "spawn". Avoid "fork".
            # See: https://github.com/python/cpython/issues/84559
            # NOTE: platform.system() returns the capitalized "Linux";
            # comparing against lowercase "linux" would never match.
            if platform.system() == "Linux":
                return "forkserver"

            # Use the platform default otherwise - "spawn" at the time of writing
            return multiprocessing.get_start_method()

        # Warn about unsafe explicit choices, then honor them.
        else:
            if platform.system() == "Darwin" and torch.backends.mps.is_available():
                if multiprocessing_context != "fork":
                    logger.warning(
                        f"Using multiprocessing context {multiprocessing_context} on MPS may cause "
                        "issues. Consider using 'fork' or 'openfold-default' (which resolves to 'fork' on MPS).",
                        stacklevel=2,
                    )
            if platform.system() == "Linux":
                # None is dangerous on Linux before 3.14 because the default
                # start method there is "fork".
                dangerous_start_method = (
                    multiprocessing_context == "fork"
                    or multiprocessing_context is None
                    and sys.version_info < (3, 14)
                )
                if dangerous_start_method:
                    logger.warning(
                        "Using 'fork' multiprocessing context in linux may cause issues. Consider using "
                        "'spawn', 'forkserver' or 'openfold-default' (which resolves to 'forkserver' on linux).",
                        stacklevel=2,
                    )

            return multiprocessing_context
225+
159226

160227
class DataModule(pl.LightningDataModule):
161228
"""A LightningDataModule class for organizing Datasets and DataLoaders."""
@@ -167,6 +234,7 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:
167234
self.batch_size = data_module_config.batch_size
168235
self.num_workers = data_module_config.num_workers
169236
self.num_workers_validation = data_module_config.num_workers_validation
237+
self.multiprocessing_context = data_module_config.multiprocessing_context
170238
self.data_seed = data_module_config.data_seed
171239
self.next_data_seed = data_module_config.data_seed
172240
self.epoch_len = data_module_config.epoch_len
@@ -433,8 +501,17 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
433501
# instead of pl.seed_everything(workers=True), so this function is
434502
# passed explicitly here.
435503
worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
504+
505+
# Set a sensible default for multiprocessing start method
506+
# depending on platform and python version.
507+
multiprocessing_context = DataModuleConfig.safe_multiprocessing_context(
508+
self.multiprocessing_context, num_workers
509+
)
510+
436511
logger.debug(
437-
f"Creating {mode} dataloader: num_workers={num_workers}, "
512+
f"Creating {mode} dataloader: "
513+
f"num_workers={num_workers}, "
514+
f"multiprocessing_context={multiprocessing_context}, "
438515
f"rank={self.global_rank}."
439516
)
440517
return DataLoader(
@@ -445,10 +522,7 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
445522
collate_fn=openfold_batch_collator,
446523
generator=self.generators[mode],
447524
worker_init_fn=worker_init_fn,
448-
# https://github.com/pytorch/pytorch/issues/87688
449-
multiprocessing_context="fork"
450-
if torch.backends.mps.is_available() and num_workers
451-
else None,
525+
multiprocessing_context=multiprocessing_context,
452526
)
453527

454528
def train_dataloader(self) -> DataLoader:

0 commit comments

Comments
 (0)