Add MutualInformation Metric #3230
Merged
+247 −1
Changes from all commits (24 commits)
2e0d9fd add MutualInformationMetric (kzkadc)
c6cf3e5 update test for MutualInformation metric (kzkadc)
06b218c format code for MutualInformation Metric (kzkadc)
896e92c Merge branch 'master' into mutual-information (kzkadc)
b85fbd1 update test for MutualInformation metric (kzkadc)
e59e05b Merge branch 'master' into mutual-information (kzkadc)
78bdbb2 Merge branch 'master' into mutual-information (kzkadc)
40e9e67 update test (kzkadc)
f7d3a41 update docstring (kzkadc)
72e1664 fix device compatibility (kzkadc)
3f5c28d fix test_accumulator_device for MutualInformation metric (kzkadc)
035e32e Merge branch 'master' into mutual-information (kzkadc)
b274037 Merge branch 'master' into mutual-information (kzkadc)
61669ee update doc (kzkadc)
51676eb modify docstring (kzkadc)
6ededec modify formula of docstring (kzkadc)
ba3e78f update formula of docstring (kzkadc)
73b7928 update formula of docstring (kzkadc)
c4bc88a remove unused import (kzkadc)
d124752 add reference (kzkadc)
366d854 commonalize redundant code (kzkadc)
d6b2ee6 modify decorator (kzkadc)
b1cc792 add a comment (kzkadc)
a23b435 fix decorator (kzkadc)
New file: MutualInformation metric implementation

@@ -0,0 +1,94 @@
```python
import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["MutualInformation"]


class MutualInformation(Entropy):
    r"""Calculates the `mutual information <https://en.wikipedia.org/wiki/Mutual_information>`_
    between input :math:`X` and prediction :math:`Y`.

    .. math::
        \begin{align*}
            I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right)
            - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\
            H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c,
        \end{align*}

    where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for the :math:`i`-th input,
    and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`.

    Intuitively, this metric measures how well input data are clustered by classes in the feature space [1].

    [1] https://proceedings.mlr.press/v70/hu17b.html

    - ``update`` must receive output of the form ``(y_pred, y)``, though ``y`` is not used in this metric.
    - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MutualInformation()
            metric.attach(default_evaluator, 'mutual_information')
            y_true = torch.tensor([0, 1, 2])  # not considered in the MutualInformation metric.
            y_pred = torch.tensor([
                [ 0.0000,  0.6931,  1.0986],
                [ 1.3863,  1.6094,  1.6094],
                [ 0.0000, -2.3026, -2.3026]
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['mutual_information'])

        .. testoutput::

            0.18599730730056763
    """

    _state_dict_all_req_keys = ("_sum_of_probabilities",)

    @reinit__is_reduced
    def reset(self) -> None:
        super().reset()
        self._sum_of_probabilities = torch.tensor(0.0, device=self._device)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        super()._update(prob, log_prob)
        # We can't use += below as _sum_of_probabilities can be a scalar and prob.sum(dim=0) is a vector
        self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device)

    @sync_all_reduce("_sum_of_probabilities", "_sum_of_entropies", "_num_examples")
    def compute(self) -> float:
        n = self._num_examples
        if n == 0:
            raise NotComputableError("MutualInformation must have at least one example before it can be computed.")

        marginal_prob = self._sum_of_probabilities / n
        marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum()
        conditional_ent = self._sum_of_entropies / n
        mi = marginal_ent - conditional_ent
        mi = torch.clamp(mi, min=0.0)  # mutual information cannot be negative
        return float(mi.item())
```
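As a sanity check on the formula in the docstring, the doctest value can be reproduced directly from the logits with a few lines of plain PyTorch. This is a minimal sketch using tensor ops rather than the metric class, assuming the same logits as the ``testcode`` block above.

```python
import torch
import torch.nn.functional as F

# Logits from the docstring example above.
y_pred = torch.tensor([
    [0.0000, 0.6931, 1.0986],
    [1.3863, 1.6094, 1.6094],
    [0.0000, -2.3026, -2.3026],
])

prob = F.softmax(y_pred, dim=1)                            # \hat{p}_i for each sample
marginal = prob.mean(dim=0)                                # (1/N) * sum_i \hat{p}_i
marginal_ent = -(marginal * marginal.log()).sum()          # H(Y)
conditional_ent = -(prob * prob.log()).sum(dim=1).mean()   # (1/N) * sum_i H(\hat{p}_i)
mi = (marginal_ent - conditional_ent).clamp(min=0.0)
print(float(mi))  # ~0.186, matching the testoutput value
```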
New file: tests for the MutualInformation metric

@@ -0,0 +1,145 @@
```python
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import MutualInformation


def np_mutual_information(np_y_pred: np.ndarray) -> float:
    prob = softmax(np_y_pred, axis=1)
    marginal_ent = entropy(np.mean(prob, axis=0))
    conditional_ent = np.mean(entropy(prob, axis=1))
    return max(0.0, marginal_ent - conditional_ent)


def test_zero_sample():
    mi = MutualInformation()
    with pytest.raises(
        NotComputableError, match=r"MutualInformation must have at least one example before it can be computed"
    ):
        mi.compute()


def test_invalid_shape():
    mi = MutualInformation()
    y_pred = torch.randn(10).float()
    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
        mi.update((y_pred, None))


@pytest.fixture(params=list(range(4)))
def test_case(request):
    return [
        (torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1),
        (torch.rand((100, 500)).float(), torch.randint(0, 500, size=[100]), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)).float(), torch.randint(0, 10, size=[100]), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)).float(), torch.randint(0, 200, size=[100]), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)).float(), torch.randint(0, 5, size=(100, 32, 32)), 16),
        (torch.randn((100, 5, 224, 224)).float(), torch.randint(0, 5, size=(100, 224, 224)), 16),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
    mi = MutualInformation()

    y_pred, y, batch_size = test_case

    mi.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            mi.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        mi.update((y_pred, y))

    np_res = np_mutual_information(y_pred.numpy())
    res = mi.compute()

    assert isinstance(res, float)
    assert pytest.approx(np_res, rel=1e-4) == res


def test_accumulator_detached():
    mi = MutualInformation()

    y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
    y = torch.zeros(2)
    mi.update((y_pred, y))

    assert not mi._sum_of_probabilities.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        tol = 1e-4
        n_iters = 100
        batch_size = 10
        n_cls = 50
        device = idist.device()
        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
            y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

            engine = Engine(
                lambda e, i: (
                    y_preds[i * batch_size : (i + 1) * batch_size],
                    y_true[i * batch_size : (i + 1) * batch_size],
                )
            )

            m = MutualInformation(device=metric_device)
            m.attach(engine, "mutual_information")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=1)

            y_preds = idist.all_gather(y_preds)
            y_true = idist.all_gather(y_true)

            assert "mutual_information" in engine.state.metrics
            res = engine.state.metrics["mutual_information"]

            true_res = np_mutual_information(y_preds.cpu().numpy())

            assert pytest.approx(true_res, rel=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            mi = MutualInformation(device=metric_device)

            devices = (mi._device, mi._sum_of_probabilities.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
            y = torch.zeros(2)
            mi.update((y_pred, y))

            devices = (mi._device, mi._sum_of_probabilities.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
```
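Beyond the test suite, the new metric can be attached to an ordinary evaluator. The sketch below is illustrative only: the linear ``model``, the random ``TensorDataset``, and the batch size are made-up placeholders, and it assumes ``MutualInformation`` is exported from ``ignite.metrics`` as this PR does.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import create_supervised_evaluator
from ignite.metrics import MutualInformation

# Hypothetical toy classifier and data, just to show the wiring.
model = nn.Linear(16, 5)  # 5-class "classifier" over 16 features
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 5, (64,)))
loader = DataLoader(dataset, batch_size=8)

# The default evaluator output is (y_pred, y); y is ignored by this metric.
evaluator = create_supervised_evaluator(model, metrics={"mutual_information": MutualInformation()})
state = evaluator.run(loader)
print(state.metrics["mutual_information"])
```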