Skip to content

Commit 59eb3cb

Browse files
committed
Add world-size getter in Engine
Signed-off-by: WoosungMyung <[email protected]>
1 parent 1d7b90a commit 59eb3cb

File tree

3 files changed

+64
-0
lines changed

3 files changed

+64
-0
lines changed

deepspeed/runtime/engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,10 @@ def get_tensor_parallel_rank(self):
748748
def get_model_parallel_rank(self):
    """Return this process's model-parallel rank, delegating to the global groups utility."""
    rank = groups.get_model_parallel_rank()
    return rank

751+
def get_parallel_world_sizes(self):
    """Return a dict of parallel world sizes for data/tensor parallelism."""
    world_sizes = {
        "dp": groups.get_data_parallel_world_size(),
        "tp": groups.get_tensor_model_parallel_world_size(),
    }
    return world_sizes
751755
def get_sequence_parallel_group(self):
    """Expose the process group used for sequence parallelism."""
    seq_group = self.seq_parallel_group
    return seq_group

deepspeed/runtime/pipe/engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,12 @@ def is_last_stage(self):
537537
def get_pipeline_parallel_rank(self):
    """Return the pipeline stage index assigned to this rank."""
    stage = self.stage_id
    return stage

540+
def get_parallel_world_sizes(self):
    """Return a dict of parallel world sizes for data/tensor/pipeline parallelism."""
    # Start from the base-engine dp/tp sizes, then layer on the pipeline size.
    world_sizes = super().get_parallel_world_sizes()
    world_sizes.update(pp=self.num_stages)
    return world_sizes
540546
def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None):
541547
if reduce is None:
542548
return outputs
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import pytest
7+
from deepspeed.runtime.engine import DeepSpeedEngine
8+
from deepspeed.runtime.pipe.engine import PipelineEngine
9+
10+
11+
# Silence destructors because instances are created via __new__ (no init)
@pytest.fixture(autouse=True)
def _silence_engine_destructors(monkeypatch):
    noop = lambda self: None
    for engine_cls in (DeepSpeedEngine, PipelineEngine):
        monkeypatch.setattr(engine_cls, "__del__", noop, raising=False)
        monkeypatch.setattr(engine_cls, "destroy", noop, raising=False)
19+
20+
# Skip if methods are absent (e.g., running against an older DS build)
if not all(
        hasattr(engine_cls, "get_parallel_world_sizes") for engine_cls in (DeepSpeedEngine, PipelineEngine)):
    pytest.skip("Required methods missing on this DeepSpeed build.", allow_module_level=True)
25+
26+
def _patch_groups(monkeypatch, dp=8, tp=4):
    """Patch deepspeed.utils.groups to avoid initializing any distributed backend."""
    import deepspeed.utils.groups as groups
    # Bind each stubbed value as a lambda default so the closures are independent.
    for fn_name, size in (("get_data_parallel_world_size", dp), ("get_tensor_model_parallel_world_size", tp)):
        monkeypatch.setattr(groups, fn_name, lambda size=size: size, raising=True)
32+
33+
def _make_engine():
    """Create engine without running __init__ to avoid side effects."""
    engine = DeepSpeedEngine.__new__(DeepSpeedEngine)
    return engine
37+
38+
def _make_pipeline_engine(num_stages=6):
    """Create pipeline engine without init; set the minimal required attribute."""
    engine = PipelineEngine.__new__(PipelineEngine)
    engine.num_stages = num_stages
    return engine
44+
45+
def test_deepspeedengine_get_parallel_world_sizes(monkeypatch):
    """Base engine reports exactly the stubbed dp/tp sizes."""
    _patch_groups(monkeypatch, dp=8, tp=4)
    engine = _make_engine()
    expected = {"dp": 8, "tp": 4}
    assert engine.get_parallel_world_sizes() == expected
50+
51+
def test_pipelineengine_get_parallel_world_sizes(monkeypatch):
    """Pipeline engine adds its stage count on top of the base dp/tp sizes."""
    _patch_groups(monkeypatch, dp=8, tp=4)
    engine = _make_pipeline_engine(num_stages=6)
    expected = {"dp": 8, "tp": 4, "pp": 6}
    assert engine.get_parallel_world_sizes() == expected

0 commit comments

Comments
 (0)