diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index af182ad7f422f..850a1f966edbc 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -124,10 +124,9 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: - return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) + # Use default stream for DDP init to match subsequent forwards/backwards and avoid + # AccumulateGrad stream mismatch warning (see torch/csrc/autograd/input_buffer.cpp in pytorch/pytorch) + return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) @override def module_to_device(self, module: Module) -> None: diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4eca6159ddced..87e3c518a5cd4 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from contextlib import nullcontext from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union @@ -190,10 +189,9 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: - return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) + # Use default stream for DDP init to match subsequent forwards/backwards and avoid + # AccumulateGrad stream mismatch warning (see torch/csrc/autograd/input_buffer.cpp in pytorch/pytorch) + return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) def setup_distributed(self) -> None: log.debug(f"{self.__class__.__name__}: setting up distributed...") diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index f302da5d1bc4f..0e0658af0ba1a 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -145,9 +145,9 @@ def test_module_init_context(precision, expected_dtype): @mock.patch.dict(os.environ, {"LOCAL_RANK": "0"}) @mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel") -@mock.patch("torch.cuda.Stream") @mock.patch("torch.cuda.stream") -def test_setup_with_cuda_stream(cuda_stream_mock, *_): +def test_setup_uses_default_stream(cuda_stream_mock, *_): + """DDP setup uses default stream to avoid AccumulateGrad stream mismatch warning.""" model = torch.nn.Linear(2, 2) strategy = DDPStrategy(parallel_devices=[torch.device("cpu")], cluster_environment=LightningEnvironment()) strategy.setup_module(model) @@ -155,7 +155,7 @@ def 
test_setup_with_cuda_stream(cuda_stream_mock, *_): strategy = DDPStrategy(parallel_devices=[torch.device("cuda", 0)], cluster_environment=LightningEnvironment()) strategy.setup_module(model) - cuda_stream_mock.assert_called_once() + cuda_stream_mock.assert_not_called() @mock.patch("torch.distributed.init_process_group")