diff --git a/CHANGELOG.md b/CHANGELOG.md
index c7770a90667..733a49ace07 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))
+- Added `EvaluationDistributedSampler` to utilities for proper distributed evaluation ([#1886](https://github.com/Lightning-AI/torchmetrics/pull/1886))
+
+
 - Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983))
diff --git a/docs/source/references/utilities.rst b/docs/source/references/utilities.rst
index 054f0cab6fe..d1963b0ce51 100644
--- a/docs/source/references/utilities.rst
+++ b/docs/source/references/utilities.rst
@@ -1,9 +1,37 @@
 .. role:: hidden
     :class: hidden-section
 
-###########################
+######################
+torchmetrics.utilities
+######################
+
+The following lists public utility functions that may be useful in your own code. Note that these functions are not
+covered by the stability guarantees of the public API and may change at any time.
+
+**********************************
+torchmetrics.utilities.distributed
+**********************************
+
+The `distributed` utilities are used to help with synchronization of metrics across multiple processes.
+
+EvaluationDistributedSampler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.utilities.distributed.EvaluationDistributedSampler
+    :noindex:
+
+gather_all_tensors
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torchmetrics.utilities.distributed.gather_all_tensors
+    :noindex:
+
+***************************
 torchmetrics.utilities.data
-###########################
+***************************
+
+The `data` utilities are used to help with data manipulation, such as converting labels in classification from one format
+to another.
 
 select_topk
 ~~~~~~~~~~~
@@ -20,9 +48,9 @@ to_onehot
 
 .. autofunction:: torchmetrics.utilities.data.to_onehot
 
-#################################
+*********************************
 torchmetrics.utilities.exceptions
-#################################
+*********************************
 
 TorchMetricsUserError
 ~~~~~~~~~~~~~~~~~~~~~
diff --git a/examples/distributed_evaluation.py b/examples/distributed_evaluation.py
new file mode 100644
index 00000000000..f10ebabc6b3
--- /dev/null
+++ b/examples/distributed_evaluation.py
@@ -0,0 +1,195 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An example of how to use the distributed evaluation utilities in both Lightning and PyTorch.
+
+To run using only PyTorch:
+    python distributed_evaluation.py
+To run using Lightning:
+    python distributed_evaluation.py --use_lightning
+
+By default, this example uses the EvaluationDistributedSampler, which is a custom sampler that ensures that no extra
+samples are added to the dataset. This is important for evaluation, as we don't want to evaluate on the same samples
+multiple times.
+
+If you want to see the difference between the EvaluationDistributedSampler and the standard DistributedSampler, you
+can add the flag --use_standard. This will use the standard DistributedSampler, which will add extra samples to the
+dataset and thus give incorrect results.
+
+"""
+import argparse
+import os
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torchmetrics
+from lightning_utilities import module_available
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.data import DataLoader, Dataset, DistributedSampler, TensorDataset
+from torchmetrics.utilities.distributed import EvaluationDistributedSampler
+
+_ = torch.manual_seed(42)
+
+
+class DummyModel(Module):
+    """Dummy model consisting of a single linear layer."""
+
+    def __init__(self, n_feature: int) -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(n_feature, 10)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        return self.linear(x)
+
+
+def calculate_accuracy_manually(dataset: Dataset, model: Module) -> Tensor:
+    """Calculate accuracy manually over the full dataset, without any distributed logic."""
+    x, y = dataset.tensors
+    preds = model(x)
+    return (preds.argmax(dim=1) == y).float().mean()
+
+
+def use_lightning(
+    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
+) -> None:
+    """Use Lightning to evaluate a model on a dataset."""
+    if module_available("lightning"):
+        from lightning.pytorch import LightningModule, Trainer
+    else:
+        from pytorch_lightning import LightningModule, Trainer
+
+    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler
+
+    class DummyLightningModule(LightningModule):
+        def __init__(self, model: Module) -> None:
+            super().__init__()
+            self.model = model
+            self.metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")
+
+        def forward(self, x: Tensor) -> Tensor:
+            return self.model(x)
+
+        def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+            preds = self(batch[0])
+            target = batch[1]
+            self.metric.update(preds, target)
+
+        def on_test_epoch_end(self) -> None:
+            self.log("test_acc", self.metric.compute())
+
+        def test_dataloader(self) -> DataLoader:
+            return DataLoader(
+                dataset,
+                batch_size=batch_size,
+                sampler=sampler_class(dataset),
+            )
+
+    model = DummyLightningModule(model)
+
+    trainer = Trainer(
+        devices=num_processes,
+        accelerator="cpu" if not gpu else "gpu",
+    )
+
+    res = trainer.test(model)
+    manual_res = calculate_accuracy_manually(dataset, model)
+    print(manual_res)
+    if torch.allclose(torch.tensor(res[0]["test_acc"]), manual_res):
+        print("success! result matched manual calculation")
+    else:
+        print("failure! result did not match manual calculation")
+
+
+def _use_torch_worker_fn(
+    rank: int, model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
+) -> None:
+    """Worker function for torch.distributed evaluation."""
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29500"
+    dist.init_process_group("nccl" if gpu else "gloo", rank=rank, world_size=num_processes)
+
+    device = torch.device(f"cuda:{rank}") if gpu else torch.device("cpu")
+
+    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler_class(dataset, num_processes, rank),
+    )
+
+    metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")
+    metric = metric.to(device)
+
+    batches, num_samples = 0, 0
+    for batch in dataloader:
+        # the model stays on CPU; only the metric states live on the (potentially CUDA) device
+        preds = model(batch[0])
+        target = batch[1]
+
+        metric.update(preds.to(device), target.to(device))
+        num_samples += len(target)
+        batches += 1
+
+    res = metric.compute()
+
+    print(f"Rank {rank} processed {num_samples} samples in {batches} batches and calculated accuracy: {res}")
+
+    manual_res = calculate_accuracy_manually(dataset, model)
+    if torch.allclose(res.cpu(), manual_res):
+        print("success! result matched manual calculation")
+    else:
+        print("failure! result did not match manual calculation")
+
+
+def use_torch(
+    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
+) -> None:
+    """Use torch.distributed to evaluate a model on a dataset."""
+    mp.spawn(
+        _use_torch_worker_fn, nprocs=num_processes, args=(model, dataset, batch_size, use_standard, num_processes, gpu)
+    )
+
+
+def main() -> None:
+    """Parse arguments and run the example."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--use_lightning", action="store_true")
+    parser.add_argument("--use_standard", action="store_true")
+    parser.add_argument("--num_processes", type=int, default=2)
+    parser.add_argument("--batch_size", type=int, default=3)
+    parser.add_argument("--gpu", action="store_true")
+    args = parser.parse_args()
+    print(args)
+
+    dataset = TensorDataset(torch.randn(199, 100), torch.randint(0, 10, (199,)))
+    n_feature = 100
+    dummy_model = DummyModel(n_feature)
+
+    if len(dataset) % (args.num_processes * args.batch_size) == 0:
+        raise ValueError(
+            "For this example the dataset size should NOT be divisible by the number of processes times the batch size."
+        )
+
+    if args.use_lightning:
+        use_lightning(dummy_model, dataset, args.batch_size, args.use_standard, args.num_processes, args.gpu)
+    else:
+        use_torch(dummy_model, dataset, args.batch_size, args.use_standard, args.num_processes, args.gpu)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/torchmetrics/utilities/__init__.py b/src/torchmetrics/utilities/__init__.py
index 234e3474873..32d6d4b6374 100644
--- a/src/torchmetrics/utilities/__init__.py
+++ b/src/torchmetrics/utilities/__init__.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 from torchmetrics.utilities.checks import check_forward_full_state_property
 from torchmetrics.utilities.data import apply_to_collection
-from torchmetrics.utilities.distributed import class_reduce, reduce
+from torchmetrics.utilities.distributed import EvaluationDistributedSampler, class_reduce, reduce
 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn
 
 __all__ = [
-    "check_forward_full_state_property",
     "apply_to_collection",
+    "check_forward_full_state_property",
     "class_reduce",
-    "reduce",
+    "EvaluationDistributedSampler",
     "rank_zero_debug",
     "rank_zero_info",
     "rank_zero_warn",
+    "reduce",
 ]
diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py
index 6b8b027bb93..47b7c12db65 100644
--- a/src/torchmetrics/utilities/distributed.py
+++ b/src/torchmetrics/utilities/distributed.py
@@ -16,6 +16,7 @@
 import torch
 from torch import Tensor
 from torch.nn import functional as F  # noqa: N812
+from torch.utils.data import Dataset
 from typing_extensions import Literal
 
 
@@ -146,3 +147,65 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
         slice_param = [slice(dim_size) for dim_size in item_size]
         gathered_result[idx] = gathered_result[idx][slice_param]
     return gathered_result
+
+
+class EvaluationDistributedSampler(torch.utils.data.DistributedSampler):
+    """A specialized distributed sampler for evaluation (test and validation).
+
+    It is derived from the PyTorch DistributedSampler, with one core difference: it does not add extra samples to make
+    the data evenly divisible across devices. This is important when evaluating, as adding extra samples would skew
+    the results towards the duplicated samples.
+
+    Normally, not adding the extra samples would lead to processes getting out of sync, but this is handled by the
+    custom synchronization in TorchMetrics. Thus, this sampler does not in general guarantee that distributed
+    operations work correctly outside of TorchMetrics.
+
+    Arguments are the same as for ``DistributedSampler``; this implementation only overrides the ``__init__`` method.
+
+    Args:
+        dataset: Dataset used for sampling.
+        num_replicas (int, optional): Number of processes participating in distributed training. By default,
+            :attr:`world_size` is retrieved from the current distributed group.
+        rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is
+            retrieved from the current distributed group.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the indices.
+        seed (int, optional): random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be
+            identical across all processes in the distributed group.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the tail of the data to make it evenly
+            divisible across the number of replicas.
+
+    For a full example of how to use this sampler, with both bare PyTorch and PyTorch Lightning, check out the
+    `distributed_evaluation.py` file in the examples folder.
+
+    Example::
+        The sampler is intended to be used in conjunction with a DataLoader:
+
+        >>> import torch
+        >>> from torch.utils.data import DataLoader, TensorDataset
+        >>> from torchmetrics.utilities.distributed import EvaluationDistributedSampler
+        >>> dataset = TensorDataset(torch.arange(10))
+        >>> dataloader = DataLoader(
+        ...     dataset, sampler=EvaluationDistributedSampler(dataset, num_replicas=2)
+        ... ) # doctest: +SKIP
+
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        # Adapted from:
+        # https://github.com/pytorch/pytorch/issues/25162#issuecomment-1227647626
+        super().__init__(
+            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
+        )
+
+        len_dataset = len(self.dataset)  # type: ignore[arg-type]
+        if not self.drop_last and len_dataset % self.num_replicas != 0:
+            # some ranks get fewer samples; TorchMetrics' synchronization handles the resulting uneven states
+            if self.rank >= len_dataset % self.num_replicas:
+                self.num_samples -= 1
+            self.total_size = len_dataset
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index ee1dc505477..b7bc5969aca 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -18,6 +18,7 @@
 from distutils.version import LooseVersion
 from typing import Optional
 
+from lightning_utilities import module_available
 from lightning_utilities.core.imports import compare_version, package_available
 
 _PYTHON_VERSION = ".".join(map(str, [sys.version_info.major, sys.version_info.minor, sys.version_info.micro]))
@@ -29,6 +30,12 @@
 _TORCH_GREATER_EQUAL_1_12: Optional[bool] = compare_version("torch", operator.ge, "1.12.0")
 _TORCH_GREATER_EQUAL_1_13: Optional[bool] = compare_version("torch", operator.ge, "1.13.0")
 
+_LIGHTNING_GREATER_EQUAL_2_0: Optional[bool] = (
+    compare_version("lightning", operator.ge, "2.0.0")
+    if module_available("lightning")
+    else compare_version("pytorch_lightning", operator.ge, "2.0.0")
+)
+
 _JIWER_AVAILABLE: bool = package_available("jiwer")
 _NLTK_AVAILABLE: bool = package_available("nltk")
 _ROUGE_SCORE_AVAILABLE: bool = package_available("rouge_score")
diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py
index 182b4e2243c..2d3640dbb48 100644
--- a/tests/integrations/test_lightning.py
+++ b/tests/integrations/test_lightning.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 from unittest import mock
 
+import pytest
 import torch
 from lightning_utilities import module_available
 from torch import tensor
 from torch.nn import Linear
+from torch.utils.data import DataLoader, DistributedSampler
 
 if module_available("lightning"):
     from lightning.pytorch import LightningModule, Trainer
@@ -27,7 +29,9 @@
 
 from torchmetrics import MetricCollection
 from torchmetrics.aggregation import SumMetric
-from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
+from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy
+from torchmetrics.utilities.distributed import EvaluationDistributedSampler
+from torchmetrics.utilities.imports import _LIGHTNING_GREATER_EQUAL_2_0
 
 from integrations.helpers import no_warning_call
 from integrations.lightning.boring_model import BoringModel
@@ -439,3 +443,67 @@ def configure_optimizers(self):
 
     model = model.type(torch.half)
     assert model.metric.sum_value.dtype == torch.float32
+
+
+class _DistributedEvaluationTestModel(BoringModel):
+    def __init__(self, dataset, sampler_class, devices) -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(32, 10)
+        self.metric = MulticlassAccuracy(num_classes=10, average="micro")
+        self.dataset = dataset
+        self.sampler_class = sampler_class
+        self.devices = devices
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def test_step(self, batch, batch_idx):
+        preds = self(batch[0]).argmax(dim=1)
+        target = batch[1]
+        self.metric.update(preds, target)
+
+    def on_test_epoch_end(self):
+        self.metric._should_unsync = False
+        result = self.metric.compute()
+        self.log("test_acc", result)
+        self.log("samples_seen", self.metric.tp + self.metric.fn)
+
+    def test_dataloader(self):
+        if self.devices > 1:
+            return DataLoader(self.dataset, batch_size=3, sampler=self.sampler_class(self.dataset, shuffle=False))
+        return DataLoader(self.dataset, batch_size=3, shuffle=False)
+
+
+@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_2_0, reason="Test requires Lightning 2.0 or newer")
+@pytest.mark.parametrize("sampler_class", [DistributedSampler, EvaluationDistributedSampler])
+@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
+@pytest.mark.parametrize("devices", [1, 2])
+def test_distributed_sampler_integration(sampler_class, accelerator, devices):
+    """Test the integration of the custom distributed sampler with Lightning."""
+    if not torch.cuda.is_available() and accelerator == "gpu":
+        pytest.skip("test requires GPU machine")
+    if torch.cuda.is_available() and accelerator == "gpu" and torch.cuda.device_count() < 2 and devices > 1:
+        pytest.skip("test requires GPU machine with at least 2 GPUs")
+
+    n_data = 199
+    dataset = torch.utils.data.TensorDataset(
+        torch.arange(n_data).unsqueeze(1).repeat(1, 32).float(),
+        torch.arange(10).repeat(20)[:n_data],
+    )
+
+    model = _DistributedEvaluationTestModel(dataset, sampler_class, devices)
+
+    trainer = Trainer(
+        devices=devices,
+        accelerator=accelerator,
+    )
+    res = trainer.test(model)
+    manual_res = (model(dataset.tensors[0]).argmax(dim=1) == dataset.tensors[1]).float().mean()
+
+    if sampler_class == DistributedSampler and devices > 1:
+        # the standard sampler adds extra samples, which skews the results
+        assert torch.allclose(torch.tensor(res[0]["samples_seen"], dtype=torch.long), torch.tensor(len(dataset) + 1))
+        assert not torch.allclose(torch.tensor(res[0]["test_acc"]), manual_res)
+    else:
+        assert torch.allclose(torch.tensor(res[0]["samples_seen"], dtype=torch.long), torch.tensor(len(dataset)))
+        assert 
torch.allclose(torch.tensor(res[0]["test_acc"]), manual_res) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index fb16e87edae..6b90c53d649 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -19,8 +19,10 @@ import pytest import torch from torch import tensor -from torchmetrics import Metric -from torchmetrics.utilities.distributed import gather_all_tensors +from torch.utils.data import DistributedSampler +from torchmetrics.aggregation import CatMetric, SumMetric +from torchmetrics.metric import Metric +from torchmetrics.utilities.distributed import EvaluationDistributedSampler, gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError from unittests import NUM_PROCESSES @@ -272,3 +274,123 @@ def _test_sync_with_empty_lists(rank): def test_sync_with_empty_lists(): """Test that syncronization of states can be enabled and disabled for compute.""" pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES)) + + +def _test_evaluation_distributed_dataloader( + rank, + dataset_size, + batch_size, + distributed_sampler_class, + rank_0_batches, + rank_0_samples, + rank_1_batches, + rank_1_samples, + metric_class, +): + """Worker function for testing the EvaluationDistributedSampler.""" + metric = metric_class() + + dataset = torch.utils.data.TensorDataset(torch.arange(1, dataset_size + 1)) + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=distributed_sampler_class(dataset, num_replicas=NUM_PROCESSES, rank=rank, shuffle=False), + ) + + batch_count, sample_count = 0, 0 + for batch in dataloader: + metric.update(batch[0]) + batch_count += 1 + sample_count += len(batch[0]) + + res = metric.compute() + if rank == 0: + assert batch_count == rank_0_batches, f"The number of batches did not match, got {batch_count}" + assert sample_count == rank_0_samples, f"The number of samples did not match, got {sample_count}" + if rank == 1: + assert batch_count == rank_1_batches, f"The number of batches did not match, got {batch_count}" + assert sample_count == rank_1_samples, f"The number of samples did not match, got {sample_count}" + + if metric_class == SumMetric: + if distributed_sampler_class == EvaluationDistributedSampler: + assert ( + res == torch.arange(1, dataset_size + 1).sum() + ), "The result of the metric did not match the expected result" + if distributed_sampler_class == torch.utils.data.DistributedSampler: + if dataset_size % NUM_PROCESSES == 0: + assert ( + res == torch.arange(1, dataset_size + 1).sum() + ), "The result of the metric did not match the expected result" + else: + assert ( + res == torch.arange(1, dataset_size + 1).sum() + 1 + ), "The result of the metric did not match the expected result" + if metric_class == CatMetric: + x = set(res.tolist()) + y = set(torch.arange(1, dataset_size + 1).tolist()) + assert x - y == set(), "The result of the metric did not match the expected result" + assert y - x == set(), "The result of the metric did not match the expected result" + + if distributed_sampler_class == torch.utils.data.DistributedSampler: + if dataset_size % NUM_PROCESSES != 0: + assert len(res) == dataset_size + 1, "The result of the metric did not match the expected result" + else: + assert len(res) == dataset_size, "The result of the metric did not match the expected result" + else: + assert len(res) == dataset_size, "The result of the metric did not match the expected result" + + +@pytest.mark.skipif(sys.platform == "win32", 
reason="DDP not available on windows") +@pytest.mark.parametrize( + ( + "dataset_size", + "batch_size", + "distributed_sampler_class", + "rank_0_batches", + "rank_0_samples", + "rank_1_batches", + "rank_1_samples", + ), + [ + (10, 1, EvaluationDistributedSampler, 5, 5, 5, 5), + (11, 1, EvaluationDistributedSampler, 6, 6, 5, 5), + (10, 3, EvaluationDistributedSampler, 2, 5, 2, 5), + (11, 3, EvaluationDistributedSampler, 2, 6, 2, 5), + # standard sampler adds samples if the dataset size is not divisible by the number of processes + (10, 1, DistributedSampler, 5, 5, 5, 5), + (11, 1, DistributedSampler, 6, 6, 6, 6), + (10, 3, DistributedSampler, 2, 5, 2, 5), + (11, 3, DistributedSampler, 2, 6, 2, 6), + ], +) +@pytest.mark.parametrize("metric_class", [SumMetric, CatMetric]) +def test_evaluation_distributed_dataloader( + dataset_size: int, + batch_size: int, + distributed_sampler_class: DistributedSampler, + rank_0_batches: int, + rank_0_samples: int, + rank_1_batches: int, + rank_1_samples: int, + metric_class: Metric, +): + """Test the EvaluationDistributedSampler. + + This sampler should not add additional samples to the dataset compared to the standard DistributedSampler. Thus we + expect different results. + + """ + pytest.pool.map( + partial( + _test_evaluation_distributed_dataloader, + dataset_size=dataset_size, + batch_size=batch_size, + distributed_sampler_class=distributed_sampler_class, + rank_0_batches=rank_0_batches, + rank_0_samples=rank_0_samples, + rank_1_batches=rank_1_batches, + rank_1_samples=rank_1_samples, + metric_class=metric_class, + ), + range(NUM_PROCESSES), + )