From 9df6835c1c210d365b77cb9503aa4026bad782fc Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 20 Mar 2024 12:20:25 +0900 Subject: [PATCH 1/8] add Entropy metric --- docs/source/metrics.rst | 1 + ignite/metrics/__init__.py | 2 + ignite/metrics/entropy.py | 82 +++++++++++ tests/ignite/metrics/test_entropy.py | 196 +++++++++++++++++++++++++++ 4 files changed, 281 insertions(+) create mode 100644 ignite/metrics/entropy.py create mode 100644 tests/ignite/metrics/test_entropy.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ca0f41661a10..0ed0290bd2b5 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -351,6 +351,7 @@ Complete list of metrics InceptionScore FID CosineSimilarity + Entropy Helpers for customizing metrics ------------------------------- diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 9d63cfdc4ac8..04b490b9486b 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -3,6 +3,7 @@ from ignite.metrics.classification_report import ClassificationReport from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, JaccardIndex, mIoU from ignite.metrics.cosine_similarity import CosineSimilarity +from ignite.metrics.entropy import Entropy from ignite.metrics.epoch_metric import EpochMetric from ignite.metrics.fbeta import Fbeta from ignite.metrics.frequency import Frequency @@ -39,6 +40,7 @@ "TopKCategoricalAccuracy", "Average", "DiceCoefficient", + "Entropy", "EpochMetric", "Fbeta", "FID", diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py new file mode 100644 index 000000000000..03c38eda4ee1 --- /dev/null +++ b/ignite/metrics/entropy.py @@ -0,0 +1,82 @@ +from typing import Sequence + +import torch +import torch.nn.functional as F + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["Entropy"] + + +class Entropy(Metric): + r"""Calculates the mean of `entropy `_. + + .. math:: H = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C -p_{i,c} \log p_{i,c}, + \quad p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{c'=1}^C \exp(z_{i,c'})} + + where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`. + + - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. + - ``y_pred`` is expected to be the unnormalized logits for each class. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = Entropy() + metric.attach(default_evaluator, 'entropy') + y_true = torch.tensor([0, 1, 2]) # not considered in the Entropy metric. + y_pred = torch.tensor([ + [ 0.0000, 0.6931, 1.0986], + [ 1.3863, 1.6094, 1.6094], + [ 0.0000, -2.3026, -2.3026] + ]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics['entropy']) + + .. testoutput:: + + 0.8902875582377116 + """ + + _state_dict_all_req_keys = ("_sum_of_entropies", "_num_examples") + + @reinit__is_reduced + def reset(self) -> None: + self._sum_of_entropies = torch.tensor(0.0, device=self._device) + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + y_pred = output[0].detach() + prob = F.softmax(y_pred, dim=1) + log_prob = F.log_softmax(y_pred, dim=1) + entropy_sum = -torch.sum(prob * log_prob) + self._sum_of_entropies += entropy_sum.to(self._device) + self._num_examples += y_pred.shape[0] + + @sync_all_reduce("_sum_of_entropies", "_num_examples") + def compute(self) -> float: + if self._num_examples == 0: + raise NotComputableError("Entropy must have at least one example before it can be computed.") + return self._sum_of_entropies.item() / self._num_examples diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py new file mode 100644 index 000000000000..6f0f9434c7a3 --- /dev/null +++ b/tests/ignite/metrics/test_entropy.py @@ -0,0 +1,196 @@ +import os + +import numpy as np +import pytest +import torch + +import ignite.distributed as idist +from ignite.exceptions import NotComputableError +from ignite.metrics import Entropy + + +def np_entropy(np_y_pred: np.ndarray): + np_y_pred = np_y_pred - np_y_pred.max(axis=1, keepdims=True) + prob = np.exp(np_y_pred) / np.sum(np.exp(np_y_pred), axis=1, keepdims=True) + log_prob = np_y_pred - np.log(np.sum(np.exp(np_y_pred), axis=1, keepdims=True)) + np_ent = -np.sum(prob * log_prob) / np_y_pred.shape[0] + return np_ent + + +def test_zero_sample(): + ent = Entropy() + with pytest.raises(NotComputableError, match=r"Entropy must have at least one example before it can be computed"): + ent.compute() + + +@pytest.fixture(params=[item for item in range(4)]) +def test_case(request): + return [ + (torch.randn((100, 10)), torch.randint(0, 10, size=[100]), 1), + (torch.rand((100, 500)), torch.randint(0, 500, size=[100]), 1), + # updated batches + (torch.normal(0.0, 5.0, size=(100, 10)), torch.randint(0, 10, size=[100]), 16), + (torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_compute(n_times, test_case): + ent = Entropy() + + y_pred, y, batch_size = test_case + + ent.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + ent.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + else: + ent.update((y_pred, y)) + + np_res = np_entropy(y_pred.numpy()) + + assert isinstance(ent.compute(), float) + assert pytest.approx(ent.compute()) == np_res + + +def _test_distrib_integration(device, tol=1e-6): + from ignite.engine import Engine + + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + def _test(metric_device): + n_iters = 100 + batch_size = 10 + n_cls = 50 + + y_true = torch.randint(0, n_cls, size=n_iters * batch_size, dtype=torch.long).to(device) + y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) + + def update(engine, i): + return ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + + engine = Engine(update) + + m = Entropy(device=metric_device) + m.attach(engine, "entropy") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "entropy" in engine.state.metrics + res = engine.state.metrics["entropy"] + + true_res = np_entropy(y_preds.numpy()) + + assert pytest.approx(res, rel=tol) == true_res + + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +def _test_distrib_accumulator_device(device): + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + device = torch.device(device) + ent = Entropy(device=metric_device) + + for dev in [ent._device, ent._sum_of_entropies.device]: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + ent.update((y_pred, y)) + + for dev in [ent._device, ent._sum_of_entropies.device]: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + + +def test_accumulator_detached(): + ent = Entropy() + + y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True) + y = torch.zeros(2) + ent.update((y_pred, y)) + + assert not ent._sum_of_entropies.requires_grad + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_distrib_nccl_gpu(distributed_context_single_node_nccl): + device = idist.device() + _test_distrib_integration(device) + _test_distrib_accumulator_device(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): + device = idist.device() + _test_distrib_integration(device) + _test_distrib_accumulator_device(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") +@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") +def test_distrib_hvd(gloo_hvd_executor): + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() + + gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): + device = idist.device() + _test_distrib_integration(device) + _test_distrib_accumulator_device(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): + device = idist.device() + _test_distrib_integration(device) + _test_distrib_accumulator_device(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_single_device_xla(): + device = idist.device() + _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) + + +def _test_distrib_xla_nprocs(index): + device = idist.device() + _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_xla_nprocs(xmp_executor): + n = int(os.environ["NUM_TPU_WORKERS"]) + xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) From d12ac46da1ab696c9948484b613defea4f0c96a4 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 20 Mar 2024 19:22:44 +0900 Subject: [PATCH 2/8] fix error in torch.randint --- tests/ignite/metrics/test_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py index 6f0f9434c7a3..a8123b13a65a 100644 --- a/tests/ignite/metrics/test_entropy.py +++ b/tests/ignite/metrics/test_entropy.py @@ -66,7 +66,7 @@ def _test(metric_device): batch_size = 10 n_cls = 50 - y_true = torch.randint(0, n_cls, size=n_iters * batch_size, dtype=torch.long).to(device) + y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device) y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) def update(engine, i): From fc7eddbf52d16dfdc8c77eae5e9c2e92d7699a9e Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Thu, 21 Mar 2024 22:30:21 +0900 Subject: [PATCH 3/8] update Entropy metric to support other shapes --- ignite/metrics/entropy.py | 10 +++++++++- tests/ignite/metrics/test_entropy.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index 03c38eda4ee1..4818df5b099e 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -18,7 +18,7 @@ class Entropy(Metric): where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`. - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - - ``y_pred`` is expected to be the unnormalized logits for each class. + - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. Args: output_transform: a callable that is used to transform the @@ -69,6 +69,14 @@ def reset(self) -> None: @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = output[0].detach() + if y_pred.ndim >= 3: + num_classes = y_pred.shape[1] + # (B, C, ...) -> (B, ..., C) -> (B*..., C) + # regarding as B*... predictions + y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes) + elif y_pred.ndim == 1: + raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.") + prob = F.softmax(y_pred, dim=1) log_prob = F.log_softmax(y_pred, dim=1) entropy_sum = -torch.sum(prob * log_prob) diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py index a8123b13a65a..03bd8720c745 100644 --- a/tests/ignite/metrics/test_entropy.py +++ b/tests/ignite/metrics/test_entropy.py @@ -23,6 +23,13 @@ def test_zero_sample(): ent.compute() +def test_invalid_shape(): + ent = Entropy() + y_pred = torch.randn(10).float() + with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"): + ent.update((y_pred, None)) + + @pytest.fixture(params=[item for item in range(4)]) def test_case(request): return [ @@ -31,6 +38,9 @@ def test_case(request): # updated batches (torch.normal(0.0, 5.0, size=(100, 10)), torch.randint(0, 10, size=[100]), 16), (torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16), + # image segmentation + (torch.randn((100, 5, 32, 32)), torch.randint(0, 5, size=(100, 32, 32)), 16), + (torch.randn((100, 5, 224, 224)), torch.randint(0, 5, size=(100, 224, 224)), 16) ][request.param] From 4cb5780ddff99c332c14951496e8928b4c40038a Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 22 Mar 2024 21:07:32 +0900 Subject: [PATCH 4/8] Update ignite/metrics/entropy.py Co-authored-by: vfdev --- ignite/metrics/entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index 4818df5b099e..96c156dce2b5 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -18,7 +18,8 @@ class Entropy(Metric): where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`. - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. + - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) + or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. Args: output_transform: a callable that is used to transform the From 204f99a1f3e21d5ee58e98bcf79beeac168a44f9 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 22 Mar 2024 21:17:03 +0900 Subject: [PATCH 5/8] update test of Entropy metric --- tests/ignite/metrics/test_entropy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py index 03bd8720c745..260f09f56922 100644 --- a/tests/ignite/metrics/test_entropy.py +++ b/tests/ignite/metrics/test_entropy.py @@ -1,6 +1,8 @@ import os import numpy as np +from scipy.stats import entropy as scipy_entropy +from scipy.special import softmax import pytest import torch @@ -10,11 +12,9 @@ def np_entropy(np_y_pred: np.ndarray): - np_y_pred = np_y_pred - np_y_pred.max(axis=1, keepdims=True) - prob = np.exp(np_y_pred) / np.sum(np.exp(np_y_pred), axis=1, keepdims=True) - log_prob = np_y_pred - np.log(np.sum(np.exp(np_y_pred), axis=1, keepdims=True)) - np_ent = -np.sum(prob * log_prob) / np_y_pred.shape[0] - return np_ent + prob = softmax(np_y_pred, axis=1) + ent = np.mean(scipy_entropy(prob, axis=1)) + return ent def test_zero_sample(): From c35b54129897521a77e203aee7cb4e9c309ad4c7 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 22 Mar 2024 14:11:41 +0100 Subject: [PATCH 6/8] Update ignite/metrics/entropy.py --- ignite/metrics/entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index 96c156dce2b5..b3d0cff21b6c 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -18,7 +18,7 @@ class Entropy(Metric): where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`. - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) + - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. Args: From f839eaf7cd57fc94564744030410e401f689a3bf Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 22 Mar 2024 23:18:26 +0900 Subject: [PATCH 7/8] format code --- ignite/metrics/entropy.py | 2 +- tests/ignite/metrics/test_entropy.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index 96c156dce2b5..b3d0cff21b6c 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -18,7 +18,7 @@ class Entropy(Metric): where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`. - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) + - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. Args: diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py index 260f09f56922..527fefbad9f0 100644 --- a/tests/ignite/metrics/test_entropy.py +++ b/tests/ignite/metrics/test_entropy.py @@ -1,10 +1,10 @@ import os import numpy as np -from scipy.stats import entropy as scipy_entropy -from scipy.special import softmax import pytest import torch +from scipy.special import softmax +from scipy.stats import entropy as scipy_entropy import ignite.distributed as idist from ignite.exceptions import NotComputableError @@ -40,7 +40,7 @@ def test_case(request): (torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16), # image segmentation (torch.randn((100, 5, 32, 32)), torch.randint(0, 5, size=(100, 32, 32)), 16), - (torch.randn((100, 5, 224, 224)), torch.randint(0, 5, size=(100, 224, 224)), 16) + (torch.randn((100, 5, 224, 224)), torch.randint(0, 5, size=(100, 224, 224)), 16), ][request.param] From 605ed51810868f34fdcf5f6a6fe8a02bf826a2aa Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 23 Mar 2024 00:31:21 +0900 Subject: [PATCH 8/8] fix error in converting Tensor to ndarray --- tests/ignite/metrics/test_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_entropy.py b/tests/ignite/metrics/test_entropy.py index 527fefbad9f0..712203572b21 100644 --- a/tests/ignite/metrics/test_entropy.py +++ b/tests/ignite/metrics/test_entropy.py @@ -99,7 +99,7 @@ def update(engine, i): assert "entropy" in engine.state.metrics res = engine.state.metrics["entropy"] - true_res = np_entropy(y_preds.numpy()) + true_res = np_entropy(y_preds.cpu().numpy()) assert pytest.approx(res, rel=tol) == true_res