|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# DeepSpeed Team |
| 5 | + |
| 6 | +import pytest |
| 7 | +from deepspeed.runtime.engine import DeepSpeedEngine |
| 8 | +from deepspeed.runtime.pipe.engine import PipelineEngine |
| 9 | + |
| 10 | + |
| 11 | +# Silence destructors because instances are created via __new__ (no init) |
| 12 | +@pytest.fixture(autouse=True) |
| 13 | +def _silence_engine_destructors(monkeypatch): |
| 14 | + monkeypatch.setattr(DeepSpeedEngine, "__del__", lambda self: None, raising=False) |
| 15 | + monkeypatch.setattr(PipelineEngine, "__del__", lambda self: None, raising=False) |
| 16 | + monkeypatch.setattr(DeepSpeedEngine, "destroy", lambda self: None, raising=False) |
| 17 | + monkeypatch.setattr(PipelineEngine, "destroy", lambda self: None, raising=False) |
| 18 | + |
| 19 | + |
| 20 | +# Skip if methods are absent (e.g., running against an older DS build) |
| 21 | +if (not hasattr(DeepSpeedEngine, "get_parallel_world_sizes") |
| 22 | + or not hasattr(PipelineEngine, "get_parallel_world_sizes")): |
| 23 | + pytest.skip("Required methods missing on this DeepSpeed build.", allow_module_level=True) |
| 24 | + |
| 25 | + |
| 26 | +def _patch_groups(monkeypatch, dp=8, tp=4): |
| 27 | + """Patch deepspeed.utils.groups to avoid initializing any distributed backend.""" |
| 28 | + import deepspeed.utils.groups as groups |
| 29 | + monkeypatch.setattr(groups, "get_data_parallel_world_size", lambda: dp, raising=True) |
| 30 | + monkeypatch.setattr(groups, "get_tensor_model_parallel_world_size", lambda: tp, raising=True) |
| 31 | + |
| 32 | + |
| 33 | +def _make_engine(): |
| 34 | + """Create engine without running __init__ to avoid side effects.""" |
| 35 | + return DeepSpeedEngine.__new__(DeepSpeedEngine) |
| 36 | + |
| 37 | + |
| 38 | +def _make_pipeline_engine(num_stages=6): |
| 39 | + """Create pipeline engine without init; set the minimal required attribute.""" |
| 40 | + pe = PipelineEngine.__new__(PipelineEngine) |
| 41 | + pe.num_stages = num_stages |
| 42 | + return pe |
| 43 | + |
| 44 | + |
| 45 | +def test_deepspeedengine_get_parallel_world_sizes(monkeypatch): |
| 46 | + _patch_groups(monkeypatch, dp=8, tp=4) |
| 47 | + eng = _make_engine() |
| 48 | + assert eng.get_parallel_world_sizes() == {"dp": 8, "tp": 4} |
| 49 | + |
| 50 | + |
| 51 | +def test_pipelineengine_get_parallel_world_sizes(monkeypatch): |
| 52 | + _patch_groups(monkeypatch, dp=8, tp=4) |
| 53 | + peng = _make_pipeline_engine(num_stages=6) |
| 54 | + assert peng.get_parallel_world_sizes() == {"dp": 8, "tp": 4, "pp": 6} |
0 commit comments