Add Entropy metric #3210

@@ -0,0 +1,91 @@
from typing import Sequence

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Entropy"]


class Entropy(Metric):
    r"""Calculates the mean `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>`_ of the
    predicted class distributions.

    .. math:: H = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C -p_{i,c} \log p_{i,c},
        \quad p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{c'=1}^C \exp(z_{i,c'})}

    where :math:`p_{i,c}` is the prediction probability of the :math:`i`-th sample belonging to class :math:`c`.

    - ``update`` must receive output of the form ``(y_pred, y)``; ``y`` is not used by this metric.
    - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metrics work with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = Entropy()
            metric.attach(default_evaluator, 'entropy')
            y_true = torch.tensor([0, 1, 2])  # not considered in the Entropy metric.
            y_pred = torch.tensor([
                [ 0.0000,  0.6931,  1.0986],
                [ 1.3863,  1.6094,  1.6094],
                [ 0.0000, -2.3026, -2.3026]
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['entropy'])

        .. testoutput::

            0.8902875582377116
    """

    _state_dict_all_req_keys = ("_sum_of_entropies", "_num_examples")

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_entropies = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        y_pred = output[0].detach()
        if y_pred.ndim >= 3:
            num_classes = y_pred.shape[1]
            # (B, C, ...) -> (B, ..., C) -> (B*..., C): treat every spatial location as its own prediction
            y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes)
        elif y_pred.ndim == 1:
            raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.")

        prob = F.softmax(y_pred, dim=1)
        log_prob = F.log_softmax(y_pred, dim=1)
        entropy_sum = -torch.sum(prob * log_prob)
        self._sum_of_entropies += entropy_sum.to(self._device)
        self._num_examples += y_pred.shape[0]

    @sync_all_reduce("_sum_of_entropies", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("Entropy must have at least one example before it can be computed.")
        return self._sum_of_entropies.item() / self._num_examples
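
The logits in the docstring example above are simply logs of small ratios (row-wise, exp(z) is proportional to (1, 2, 3), (4, 5, 5) and (1, 0.1, 0.1)), so the expected output of roughly 0.8903 can be reproduced by hand. The snippet below is a minimal, self-contained sanity check, not part of the PR itself; it only uses the public Entropy API added above and compares it against a direct softmax-and-entropy computation in PyTorch.

import torch
import torch.nn.functional as F

from ignite.metrics import Entropy

y_pred = torch.tensor([
    [0.0000, 0.6931, 1.0986],    # ~ log([1, 2, 3])    -> probs (1/6, 2/6, 3/6)
    [1.3863, 1.6094, 1.6094],    # ~ log([4, 5, 5])    -> probs (4/14, 5/14, 5/14)
    [0.0000, -2.3026, -2.3026],  # ~ log([1, .1, .1])  -> probs (10/12, 1/12, 1/12)
])
y = torch.tensor([0, 1, 2])  # ignored by the Entropy metric

# Manual reference: softmax over classes, then the mean per-sample Shannon entropy.
prob = F.softmax(y_pred, dim=1)
manual = -(prob * prob.log()).sum(dim=1).mean()

metric = Entropy()
metric.update((y_pred, y))
print(manual.item(), metric.compute())  # both ~ 0.8903

# (B, C, ...) inputs are also accepted, e.g. segmentation logits of shape (B, C, H, W):
# every spatial location counts as one prediction.
seg_logits = torch.randn(2, 3, 4, 4)
metric.reset()
metric.update((seg_logits, torch.zeros(2, 4, 4, dtype=torch.long)))
print(metric.compute())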

@@ -0,0 +1,206 @@
import os

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy as scipy_entropy

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy


def np_entropy(np_y_pred: np.ndarray):
    prob = softmax(np_y_pred, axis=1)
    ent = np.mean(scipy_entropy(prob, axis=1))
    return ent


def test_zero_sample():
    ent = Entropy()
    with pytest.raises(NotComputableError, match=r"Entropy must have at least one example before it can be computed"):
        ent.compute()


def test_invalid_shape():
    ent = Entropy()
    y_pred = torch.randn(10).float()
    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
        ent.update((y_pred, None))


@pytest.fixture(params=list(range(6)))  # one parameter per test case below
def test_case(request):
    return [
        (torch.randn((100, 10)), torch.randint(0, 10, size=[100]), 1),
        (torch.rand((100, 500)), torch.randint(0, 500, size=[100]), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)), torch.randint(0, 10, size=[100]), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)), torch.randint(0, 5, size=(100, 32, 32)), 16),
        (torch.randn((100, 5, 224, 224)), torch.randint(0, 5, size=(100, 224, 224)), 16),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case):
    ent = Entropy()

    y_pred, y, batch_size = test_case

    ent.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            ent.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        ent.update((y_pred, y))

    np_res = np_entropy(y_pred.numpy())

    assert isinstance(ent.compute(), float)
    assert pytest.approx(ent.compute()) == np_res


def _test_distrib_integration(device, tol=1e-6):
    from ignite.engine import Engine

    rank = idist.get_rank()
    torch.manual_seed(12 + rank)

    def _test(metric_device):
        n_iters = 100
        batch_size = 10
        n_cls = 50

        y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
        y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

        def update(engine, i):
            return (
                y_preds[i * batch_size : (i + 1) * batch_size],
                y_true[i * batch_size : (i + 1) * batch_size],
            )

        engine = Engine(update)

        m = Entropy(device=metric_device)
        m.attach(engine, "entropy")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=1)

        y_preds = idist.all_gather(y_preds)
        y_true = idist.all_gather(y_true)

        assert "entropy" in engine.state.metrics
        res = engine.state.metrics["entropy"]

        # Gathered predictions may live on GPU, so move them to CPU before the numpy reference.
        true_res = np_entropy(y_preds.cpu().numpy())

        assert pytest.approx(res, rel=tol) == true_res

    _test("cpu")
    if device.type != "xla":
        _test(idist.device())


def _test_distrib_accumulator_device(device):
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        device = torch.device(device)
        ent = Entropy(device=metric_device)

        for dev in [ent._device, ent._sum_of_entropies.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

        y_pred = torch.tensor([[2.0], [-2.0]])
        y = torch.zeros(2)
        ent.update((y_pred, y))

        for dev in [ent._device, ent._sum_of_entropies.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


def test_accumulator_detached():
    ent = Entropy()

    y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
    y = torch.zeros(2)
    ent.update((y_pred, y))

    assert not ent._sum_of_entropies.requires_grad


Collaborator comment on the distributed tests below: for the distributed config tests, could you please
rewrite them using the new testing formalism that we are trying to adopt? Here is an example of the code
to take inspiration from: ignite/tests/ignite/metrics/test_recall.py, lines 422 to 426 in 5fe7443.
There is also a PR showing how to pass from the old code to the new one. Thanks!


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
    gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_integration(device, tol=1e-4)
    _test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_integration(device, tol=1e-4)
    _test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
    n = int(os.environ["NUM_TPU_WORKERS"])
    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
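
Following up on the collaborator comment above about the new testing formalism: the block below is a rough, illustrative sketch of how the old-style test_distrib_* functions might be collapsed into the class-based pattern used in tests/ignite/metrics/test_recall.py. It is not part of this PR, and it assumes a module-level "distributed" pytest fixture provided by ignite's test conftest; the exact fixture and marker names should be checked against the current test suite.

# Sketch only (assumed fixture and class names, not part of the PR):
@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        device = idist.device()
        _test_distrib_integration(device, tol=1e-6)

    def test_accumulator_device(self):
        device = idist.device()
        _test_distrib_accumulator_device(device)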