Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = (
value.clone().detach().to(self.device)
if isinstance(value, Tensor)
else torch.tensor(value, device=self.device)
)
value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
Expand Down
3 changes: 0 additions & 3 deletions src/lightning/pytorch/loops/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,6 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
trainer = self.trainer

assert self._results is not None
self._results.to(device=trainer.lightning_module.device)

hook_name = "on_test_start" if trainer.testing else "on_validation_start"
call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,6 @@ def on_run_start(self) -> None:

self._data_fetcher = _select_data_fetcher(trainer)

self._results.to(device=trainer.lightning_module.device)

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
from datetime import timedelta
from typing import Any, Callable, Dict, List, Literal, Optional, Union

Expand Down Expand Up @@ -182,7 +183,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
if torch.cuda.is_available():
ctx = torch.cuda.stream(torch.cuda.Stream())
else:
ctx = nullcontext()
with ctx:
ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
return ddp_model

def setup_distributed(self) -> None:
log.debug(f"{self.__class__.__name__}: setting up distributed...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,9 @@ class _ResultCollection(dict):

DATALOADER_SUFFIX = "/dataloader_idx_{}"

def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
def __init__(self, training: bool) -> None:
super().__init__()
self.training = training
self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.dataloader_idx: Optional[int] = None
Expand Down Expand Up @@ -410,13 +409,13 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:

Value can be provided as a nested collection.
"""
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: call `forward` directly instead of `__call__` to skip the checks in `torch.nn.Module._call_impl`
result_metric.forward(value.to(self.device), batch_size)
result_metric.forward(value, batch_size)
result_metric.has_reset = False

@staticmethod
Expand Down Expand Up @@ -509,9 +508,6 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))

if "device" in kwargs:
self.device = kwargs["device"]
return self

def cpu(self) -> "_ResultCollection":
Expand All @@ -524,4 +520,4 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}({self_str})"

def __repr__(self) -> str:
return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}"
return f"{{{self.training}, {super().__repr__()}}}"
9 changes: 4 additions & 5 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def result_reduce_ddp_fn(strategy):
metric_b = metric_b.to(f"cuda:{rank}")
metric_c = metric_c.to(f"cuda:{rank}")

result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_result_metric_integration():
metric_b = DummyMetric()
metric_c = DummyMetric()

result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
Expand Down Expand Up @@ -148,7 +148,6 @@ def test_result_metric_integration():
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"{'h.a': _ResultMetric('a', value=DummyMetric()), "
"'h.b': _ResultMetric('b', value=DummyMetric()), "
"'h.c': _ResultMetric('c', value=DummyMetric())"
Expand All @@ -157,7 +156,7 @@ def test_result_metric_integration():


def test_result_collection_simple_loop():
result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)
current_fx_name = None
batch_idx = None

Expand Down Expand Up @@ -205,7 +204,7 @@ def my_sync_dist(x, *_, **__):
def test_result_collection_restoration(tmpdir):
"""This test makes sure metrics are properly reloaded on failure."""

result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ def on_validation_epoch_end(self) -> None:
assert set(trainer.callback_metrics) == {"val_loss", "val_loss_epoch"}

# make sure values are correct
assert trainer.logged_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
assert trainer.callback_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
assert trainer.callback_metrics["val_loss"] == model.manual_epoch_end_mean
assert torch.allclose(trainer.logged_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
assert torch.allclose(trainer.callback_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
assert torch.allclose(trainer.callback_metrics["val_loss"], model.manual_epoch_end_mean)
assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)

results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
Expand All @@ -615,7 +615,7 @@ def test_result_collection_batch_size_extraction():
assert isinstance(results["training_step.mse"].value, MeanSquaredError)
assert results["training_step.log_val"].value == log_val

results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
Expand All @@ -624,7 +624,7 @@ def test_result_collection_batch_size_extraction():


def test_result_collection_no_batch_size_extraction():
results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10
Expand Down