Skip to content

Commit 55d26b1

Browse files
authored
fix(Multiprocessing): Fully disable multiprocessing when not used (#281)
This PR fully disables multiprocessing when `max_processes = 1`. This speeds up execution on Windows-based systems, as the overhead of spawning processes using the `spawn` context is large. For this reason, this PR also disables the multiprocessing checkpointing tests on all systems not running Linux, since Linux by default uses `fork`, which comes with less overhead. Finally, a better error message is given when a user runs an unguarded benchmark (main script is not wrapped in `if __name__ == '__main__'`) while using multiprocessing (`max_processes > 1`). closes #279
1 parent 45f3100 commit 55d26b1

6 files changed

Lines changed: 102 additions & 60 deletions

File tree

decent_bench/benchmark/_benchmark.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ def resume_benchmark( # noqa: PLR0912
135135
raise ValueError(f"Invalid checkpoint directory: metadata is not valid JSON - {e}") from e
136136

137137
if create_backup:
138-
backup_path = checkpoint_manager.create_backup()
139-
LOGGER.info(f"Created backup of checkpoint directory at '{backup_path}'")
138+
checkpoint_manager.create_backup()
140139

141140
LOGGER.info(
142141
f"Resuming benchmark from checkpoint '{checkpoint_manager.checkpoint_dir}' with {metadata['n_trials']} trials "
@@ -189,7 +188,8 @@ def resume_benchmark( # noqa: PLR0912
189188
checkpoint_manager=checkpoint_manager,
190189
runtime_metrics=runtime_metrics,
191190
)
192-
log_listener.stop()
191+
if log_listener is not None:
192+
log_listener.stop()
193193
return results
194194

195195

@@ -282,15 +282,16 @@ def benchmark(
282282
checkpoint_manager=checkpoint_manager,
283283
runtime_metrics=runtime_metrics,
284284
)
285-
log_listener.stop()
285+
if log_listener is not None:
286+
log_listener.stop()
286287
return results
287288

288289

289290
def _benchmark(
290291
algorithms: list[Algorithm[Network]],
291292
benchmark_problem: BenchmarkProblem,
292-
log_listener: QueueListener,
293-
manager: "SyncManager",
293+
log_listener: QueueListener | None,
294+
manager: "SyncManager | None",
294295
*,
295296
mp_context: "SpawnContext | None" = None,
296297
n_trials: int = 30,
@@ -374,16 +375,26 @@ def _init_logging_and_multiprocessing(
374375
log_level: int,
375376
max_processes: int | None,
376377
benchmark_problem: BenchmarkProblem,
377-
) -> tuple[QueueListener, "SyncManager", "SpawnContext | None"]:
378+
) -> tuple[QueueListener | None, "SyncManager | None", "SpawnContext | None"]:
378379
# Detect if PyTorch costs are being used to determine multiprocessing context
379-
if max_processes != 1:
380-
use_spawn = _should_use_spawn_context(benchmark_problem)
381-
mp_context = get_context("spawn") if use_spawn else None
382-
else:
383-
use_spawn = False
384-
mp_context = None
385-
386-
manager = Manager() if not use_spawn else get_context("spawn").Manager()
380+
if max_processes == 1:
381+
logger.start_logger(log_level)
382+
return None, None, None
383+
384+
use_spawn = _should_use_spawn_context(benchmark_problem)
385+
mp_context = get_context("spawn") if use_spawn else None
386+
try:
387+
manager = Manager() if mp_context is None else mp_context.Manager()
388+
except RuntimeError as e:
389+
if _is_multiprocessing_main_guard_error(e):
390+
raise RuntimeError(
391+
"Failed to start multiprocessing workers. Benchmark execution "
392+
"must be launched inside a guarded main entrypoint. Wrap your benchmark call in:\n\n"
393+
"if __name__ == '__main__':\n"
394+
" ... call decent_bench.benchmark(...)\n\n"
395+
"This prevents child processes from re-running top-level script code during import."
396+
) from e
397+
raise
387398
log_listener = logger.start_log_listener(manager, log_level)
388399

389400
if use_spawn:
@@ -392,12 +403,18 @@ def _init_logging_and_multiprocessing(
392403
return log_listener, manager, mp_context
393404

394405

406+
def _is_multiprocessing_main_guard_error(exc: RuntimeError) -> bool:
407+
"""Return True for the common spawn bootstrap error caused by missing main guard."""
408+
msg = str(exc)
409+
return "start a new process before the" in msg and "bootstrapping phase" in msg
410+
411+
395412
def _run_trials( # noqa: PLR0917
396413
algorithms: list[Algorithm[Network]],
397414
n_trials: int,
398415
problem: BenchmarkProblem,
399416
progress_bar_ctrl: ProgressBarController,
400-
log_listener: QueueListener,
417+
log_listener: QueueListener | None,
401418
max_processes: int | None,
402419
mp_context: "SpawnContext | None" = None,
403420
checkpoint_manager: "CheckpointManager | None" = None,
@@ -467,6 +484,12 @@ def _run_trials( # noqa: PLR0917
467484
if max_processes == 1:
468485
partial_result = {alg: [_run_trial(*args) for args in trial_args[alg]] for alg in trial_args}
469486
else:
487+
if log_listener is None:
488+
# This shouldn't happen: internal invariant violation
489+
raise RuntimeError(
490+
"Log listener must be initialized for multiprocessing to handle logs from worker processes"
491+
)
492+
470493
with ProcessPoolExecutor(
471494
initializer=logger.start_queue_logger,
472495
initargs=(log_listener.queue,),

decent_bench/utils/progress_bar.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class ProgressBarController:
167167
Args:
168168
manager: A multiprocessing :class:`~multiprocessing.managers.SyncManager` instance used to create a shared queue
169169
for coordinating progress updates across multiple processes. This enables thread-safe communication between
170-
worker processes and the progress bar listener thread.
170+
worker processes and the progress bar listener thread. If ``None``, a local in-process queue is used.
171171
algorithms: algorithms that will be run, each gets its own bar
172172
n_trials: number of trials the algorithms will run
173173
progress_step: if provided, the progress bar will step every `progress_step`.
@@ -182,14 +182,17 @@ class ProgressBarController:
182182

183183
def __init__( # noqa: PLR0917
184184
self,
185-
manager: SyncManager,
185+
manager: SyncManager | None,
186186
algorithms: Sequence[Algorithm[Any]],
187187
n_trials: int,
188188
progress_step: int | None,
189189
show_speed: bool = False,
190190
show_trial: bool = False,
191191
):
192-
self._progress_increment_queue: Queue[_ProgressRecord | None] = manager.Queue()
192+
# Use a local queue for single-process runs to avoid multiprocessing manager overhead.
193+
self._progress_increment_queue: Queue[_ProgressRecord | None] = (
194+
manager.Queue() if manager is not None else Queue()
195+
)
193196
self.progress_step = progress_step
194197
p_cols = [
195198
(TextColumn("{task.description}"), Text("Algorithm", style="bold")),

docs/source/api/decent_bench.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ decent\_bench
1717
decent_bench.costs
1818
decent_bench.datasets
1919
decent_bench.distributed_algorithms
20-
decent_bench.utils.network_utils
2120
decent_bench.networks
2221
decent_bench.schemes
2322

docs/source/user.rst

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ Generally benchmark execution involves three steps:
2222
2. Compute metrics from the benchmark results, which returns a :class:`~decent_bench.benchmark.MetricResult` object.
2323
3. Display the computed metrics in tables and plots.
2424

25+
Note:
26+
When running benchmarks, be sure to guard the execution code with ``if __name__ == "__main__":`` to avoid issues with multiprocessing on some platforms (e.g., Windows).
27+
This is a common Python practice to ensure that the benchmark code only runs when the script is executed directly, and not when it is imported as a module or when worker
28+
processes are spawned for multiprocessing. If you forget to include this guard and you are using multiprocessing, i.e. with ``max_processes > 1`` in :func:`~decent_bench.benchmark.benchmark`,
29+
you may encounter errors or unexpected behavior due to the way multiprocessing works on different platforms.
30+
2531
**The following is a working example. The remainder of the user guide will be updated soon.**
2632

2733
.. code-block:: python
@@ -36,39 +42,40 @@ Generally benchmark execution involves three steps:
3642
3743
import networkx as nx
3844
39-
## problem definition
40-
n_agents = 10
45+
if __name__ == "__main__":
46+
## problem definition
47+
n_agents = 10
4148
42-
costs, x_optimal = create_quadratic_problem(10, n_agents)
49+
costs, x_optimal = create_quadratic_problem(10, n_agents)
4350
44-
agents = [Agent(i, cost) for i, cost in enumerate(costs)]
45-
graph = nx.complete_graph(n_agents)
46-
47-
net = P2PNetwork(
48-
graph=graph,
49-
agents=agents,
50-
)
51+
agents = [Agent(i, cost) for i, cost in enumerate(costs)]
52+
graph = nx.complete_graph(n_agents)
53+
54+
net = P2PNetwork(
55+
graph=graph,
56+
agents=agents,
57+
)
5158
52-
bp = benchmark.BenchmarkProblem(net, x_optimal)
59+
bp = benchmark.BenchmarkProblem(net, x_optimal)
5360
54-
## benchmarking
55-
cm = CheckpointManager(checkpoint_dir="results/benchmark_1", checkpoint_step=100, keep_n_checkpoints=2)
61+
## benchmarking
62+
cm = CheckpointManager(checkpoint_dir="results/benchmark_1", checkpoint_step=100, keep_n_checkpoints=2)
5663
57-
num_iter = 1000
58-
step = 0.001
64+
num_iter = 1000
65+
step = 0.001
5966
60-
res = benchmark.benchmark(algorithms=[
61-
DGD(iterations=num_iter, step_size=step),
62-
ATC(iterations=num_iter, step_size=step),
63-
],
64-
benchmark_problem=bp,
65-
checkpoint_manager=cm,
66-
n_trials=1,
67-
)
67+
res = benchmark.benchmark(algorithms=[
68+
DGD(iterations=num_iter, step_size=step),
69+
ATC(iterations=num_iter, step_size=step),
70+
],
71+
benchmark_problem=bp,
72+
checkpoint_manager=cm,
73+
n_trials=1,
74+
)
6875
69-
metr = benchmark.compute_metrics(res, checkpoint_manager=cm)
76+
metr = benchmark.compute_metrics(res, checkpoint_manager=cm)
7077
71-
benchmark.display_metrics(metr, checkpoint_manager=cm)
78+
benchmark.display_metrics(metr, checkpoint_manager=cm)
7279
7380
7481
Benchmark executions will have outputs like these:

readthedocs.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ build:
99
os: ubuntu-24.04
1010
tools:
1111
python: "3.13"
12+
jobs:
13+
build:
14+
html:
15+
- mkdir -p $READTHEDOCS_OUTPUT/html/
16+
- python -m sphinx -T -W --keep-going -j 1 -b html -d _build/doctrees -D language=en docs/source $READTHEDOCS_OUTPUT/html
1217

1318
# Build documentation in the "docs/" directory with Sphinx
1419
sphinx:

test/utils/test_checkpoints.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
import random
5+
import sys
56
from copy import deepcopy
67
from dataclasses import dataclass
78
from pathlib import Path
@@ -38,6 +39,15 @@
3839
# Suppress JAX debug logs that cause issues during cleanup
3940
logging.getLogger("jax").setLevel(logging.WARNING)
4041

42+
IS_LINUX = sys.platform.startswith("linux")
43+
LINUX_ONLY_MP_GT1 = pytest.mark.skipif(not IS_LINUX, reason="max_processes > 1 is Linux-only")
44+
45+
46+
def _skip_if_max_processes_exceeds_cpu_count(max_processes: int) -> None:
47+
cpu_count = os.cpu_count()
48+
if cpu_count is not None and max_processes > cpu_count:
49+
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
50+
4151

4252
@dataclass(eq=False)
4353
class DummyAlg(DGD):
@@ -298,7 +308,7 @@ def test_create_backup_and_clear(tmp_path: Path) -> None: # noqa: D103
298308
("cost_cls", "max_processes"),
299309
[
300310
(LogisticRegressionCost, 1),
301-
(LogisticRegressionCost, 2),
311+
pytest.param(LogisticRegressionCost, 2, marks=LINUX_ONLY_MP_GT1),
302312
pytest.param(
303313
PyTorchCost,
304314
1,
@@ -315,8 +325,7 @@ def test_resume_from_checkpoint_with_additional_trials(
315325
max_processes: int,
316326
seed: int | None,
317327
) -> None:
318-
if os.cpu_count() is not None and max_processes > os.cpu_count():
319-
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
328+
_skip_if_max_processes_exceeds_cpu_count(max_processes)
320329

321330
if seed is not None:
322331
iop.set_seed(seed)
@@ -405,7 +414,7 @@ def test_resume_from_checkpoint_with_additional_trials(
405414
("cost_cls", "max_processes"),
406415
[
407416
(LogisticRegressionCost, 1),
408-
(LogisticRegressionCost, 2),
417+
pytest.param(LogisticRegressionCost, 2, marks=LINUX_ONLY_MP_GT1),
409418
pytest.param(
410419
PyTorchCost,
411420
1,
@@ -422,8 +431,7 @@ def test_resume_from_checkpoint_with_additional_iterations(
422431
max_processes: int,
423432
seed: int | None,
424433
) -> None:
425-
if os.cpu_count() is not None and max_processes > os.cpu_count():
426-
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
434+
_skip_if_max_processes_exceeds_cpu_count(max_processes)
427435

428436
if seed is not None:
429437
iop.set_seed(seed)
@@ -514,7 +522,7 @@ def test_resume_from_checkpoint_with_additional_iterations(
514522
("cost_cls", "max_processes"),
515523
[
516524
(LogisticRegressionCost, 1),
517-
(LogisticRegressionCost, 2),
525+
pytest.param(LogisticRegressionCost, 2, marks=LINUX_ONLY_MP_GT1),
518526
pytest.param(
519527
PyTorchCost,
520528
1,
@@ -531,8 +539,7 @@ def test_resume_from_checkpoint_with_additional_iterations_and_trials(
531539
max_processes: int,
532540
seed: int | None,
533541
) -> None:
534-
if os.cpu_count() is not None and max_processes > os.cpu_count():
535-
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
542+
_skip_if_max_processes_exceeds_cpu_count(max_processes)
536543

537544
if seed is not None:
538545
iop.set_seed(seed)
@@ -624,7 +631,7 @@ def test_resume_from_checkpoint_with_additional_iterations_and_trials(
624631
("cost_cls", "max_processes"),
625632
[
626633
(LogisticRegressionCost, 1),
627-
(LogisticRegressionCost, 2),
634+
pytest.param(LogisticRegressionCost, 2, marks=LINUX_ONLY_MP_GT1),
628635
pytest.param(
629636
PyTorchCost,
630637
1,
@@ -641,8 +648,7 @@ def test_resume_from_non_completed_checkpoint(
641648
max_processes: int,
642649
seed: int | None,
643650
) -> None:
644-
if os.cpu_count() is not None and max_processes > os.cpu_count():
645-
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
651+
_skip_if_max_processes_exceeds_cpu_count(max_processes)
646652

647653
if seed is not None:
648654
iop.set_seed(seed)
@@ -774,7 +780,7 @@ def test_resume_from_non_completed_checkpoint(
774780
("cost_cls", "max_processes"),
775781
[
776782
(LogisticRegressionCost, 1),
777-
(LogisticRegressionCost, 2),
783+
pytest.param(LogisticRegressionCost, 2, marks=LINUX_ONLY_MP_GT1),
778784
pytest.param(
779785
PyTorchCost,
780786
1,
@@ -789,8 +795,7 @@ def test_back_to_back_benchmarks(
789795
cost_cls: type[LogisticRegressionCost | PyTorchCost],
790796
max_processes: int,
791797
) -> None:
792-
if os.cpu_count() is not None and max_processes > os.cpu_count():
793-
pytest.skip(f"max_processes={max_processes} exceeds available CPU cores")
798+
_skip_if_max_processes_exceeds_cpu_count(max_processes)
794799

795800
iop.set_seed(123)
796801
problem_5, algorithms_5 = _build_problem_and_algorithms(5, cost_cls=cost_cls)

0 commit comments

Comments
 (0)