Add PearsonCorrelation metric #3212
Merged
Commits (21)
1. `dba3049` (kzkadc): add PearsonCorrelation metric
2. `1f80ede` (kzkadc): match the notation of the docstring with the other metrics
3. `d5f3b0c` (kzkadc): Merge branch 'master' into correlation-coefficient
4. `69f7f5e` (kzkadc): move PearsonCorrelation metric from contrib.metrics.regression to met…
5. `855a505` (kzkadc): update test for PearsonCorrelation metric
6. `6b17c36` (kzkadc): update test
7. `fb5b8f4` (kzkadc): Merge branch 'master' into correlation-coefficient
8. `194f213` (kzkadc): Merge branch 'correlation-coefficient_rewrite-test' into correlation-…
9. `7b1acd3` (kzkadc): Merge branch 'master' into correlation-coefficient
10. `62e65d2` (kzkadc): Merge branch 'master' into correlation-coefficient
11. `ce701bc` (kzkadc): Merge branch 'master' into correlation-coefficient
12. `b931ec9` (kzkadc): modify doc for PearsonCorrelation metric
13. `6609925` (kzkadc): fix import
14. `9069779` (kzkadc): resolve code formatting issue
15. `4a0096d` (kzkadc): remove loop from test
16. `1de9181` (vfdev-5): Merge branch 'master' into correlation-coefficient
17. `ad1e090` (kzkadc): Update ignite/metrics/regression/pearson_correlation.py
18. `4d6b18e` (vfdev-5): Update pearson_correlation.py
19. `ad41bbc` (kzkadc): update test for PearsonCorrelation
20. `30f1684` (vfdev-5): Update tests/ignite/metrics/regression/test_pearson_correlation.py
21. `c6b6d92` (kzkadc): relax pytest.approx
ignite/metrics/regression/pearson_correlation.py (new file, 123 additions)

```python
from typing import Callable, Tuple, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
from ignite.metrics.regression._base import _BaseRegression


class PearsonCorrelation(_BaseRegression):
    r"""Calculates the
    `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.

    .. math::
        r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}
        {\max (\sqrt{\sum_{j=1}^n (P_j-\bar{P})^2 \sum_{j=1}^n (A_j-\bar{A})^2}, \epsilon)},
        \quad \bar{P}=\frac{1}{n}\sum_{j=1}^n P_j, \quad \bar{A}=\frac{1}{n}\sum_{j=1}^n A_j

    where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

    Parameters are inherited from ``Metric.__init__``.

    Args:
        eps: a small value to avoid division by zero. Default: 1e-8
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = PearsonCorrelation()
            metric.attach(default_evaluator, 'corr')
            y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
            y_pred = torch.tensor([0.5, 1.3, 1.9, 2.8, 4.1, 6.0])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['corr'])

        .. testoutput::

            0.9768688678741455
    """

    def __init__(
        self,
        eps: float = 1e-8,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        super().__init__(output_transform, device)

        self.eps = eps

    _state_dict_all_req_keys = (
        "_sum_of_y_preds",
        "_sum_of_ys",
        "_sum_of_y_pred_squares",
        "_sum_of_y_squares",
        "_sum_of_products",
        "_num_examples",
    )

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_y_preds = torch.tensor(0.0, device=self._device)
        self._sum_of_ys = torch.tensor(0.0, device=self._device)
        self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device)
        self._sum_of_y_squares = torch.tensor(0.0, device=self._device)
        self._sum_of_products = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        self._sum_of_y_preds += y_pred.sum()
        self._sum_of_ys += y.sum()
        self._sum_of_y_pred_squares += y_pred.square().sum()
        self._sum_of_y_squares += y.square().sum()
        self._sum_of_products += (y_pred * y).sum()
        self._num_examples += y.shape[0]

    @sync_all_reduce(
        "_sum_of_y_preds",
        "_sum_of_ys",
        "_sum_of_y_pred_squares",
        "_sum_of_y_squares",
        "_sum_of_products",
        "_num_examples",
    )
    def compute(self) -> float:
        n = self._num_examples
        if n == 0:
            raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.")

        # cov = E[xy] - E[x]*E[y]
        cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n)

        # var = E[x^2] - E[x]^2
        y_pred_mean = self._sum_of_y_preds / n
        y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean
        y_pred_var = torch.clamp(y_pred_var, min=0.0)

        y_mean = self._sum_of_ys / n
        y_var = self._sum_of_y_squares / n - y_mean * y_mean
        y_var = torch.clamp(y_var, min=0.0)

        r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps)
        return float(r.item())
```
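The `compute` method relies on the identities cov(P, A) = E[PA] − E[P]E[A] and var(X) = E[X²] − E[X]², which let the metric keep five scalar running sums across `update` calls instead of storing every prediction. A minimal pure-Python sketch (no ignite or torch; the function names here are illustrative, not part of the library) checking that the streaming form over batches agrees with the direct two-pass definition:

```python
import math


def streaming_pearson(batches, eps=1e-8):
    # Accumulate the same five sums the metric keeps across update() calls.
    sp = sa = spp = saa = spa = 0.0
    n = 0
    for preds, actuals in batches:
        for p, a in zip(preds, actuals):
            sp += p
            sa += a
            spp += p * p
            saa += a * a
            spa += p * a
            n += 1
    # cov = E[PA] - E[P]E[A]; var = E[X^2] - E[X]^2, clamped at zero
    cov = spa / n - sp * sa / (n * n)
    var_p = max(spp / n - (sp / n) ** 2, 0.0)
    var_a = max(saa / n - (sa / n) ** 2, 0.0)
    return cov / max(math.sqrt(var_p * var_a), eps)


def direct_pearson(preds, actuals):
    # Two-pass textbook definition, for comparison.
    n = len(preds)
    mp = sum(preds) / n
    ma = sum(actuals) / n
    cov = sum((p - mp) * (a - ma) for p, a in zip(preds, actuals)) / n
    sd = math.sqrt(sum((p - mp) ** 2 for p in preds) / n) * math.sqrt(
        sum((a - ma) ** 2 for a in actuals) / n
    )
    return cov / sd


# Same data as the docstring example, fed in two batches.
preds = [0.5, 1.3, 1.9, 2.8, 4.1, 6.0]
actuals = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
r_stream = streaming_pearson([(preds[:3], actuals[:3]), (preds[3:], actuals[3:])])
r_direct = direct_pearson(preds, actuals)
```

The single-precision accumulation in the real metric can lose accuracy relative to this double-precision sketch, which is presumably why the tests below compare against scipy with a relative tolerance rather than exact equality.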
tests/ignite/metrics/regression/test_pearson_correlation.py (new file, 258 additions)

```python
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.stats import pearsonr
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.regression import PearsonCorrelation


def np_corr_eps(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
    cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
    std_y_pred = np.std(np_y_pred, ddof=0)
    std_y = np.std(np_y, ddof=0)
    corr = cov / np.clip(std_y_pred * std_y, eps, None)
    return corr


def scipy_corr(np_y_pred: np.ndarray, np_y: np.ndarray):
    corr = pearsonr(np_y_pred, np_y)
    return corr.statistic


def test_zero_sample():
    m = PearsonCorrelation()
    with pytest.raises(
        NotComputableError, match=r"PearsonCorrelation must have at least one example before it can be computed"
    ):
        m.compute()


def test_wrong_input_shapes():
    m = PearsonCorrelation()

    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
        m.update((torch.rand(4), torch.rand(4, 1)))

    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
        m.update((torch.rand(4, 1), torch.rand(4)))


def test_degenerated_sample():
    # one sample
    m = PearsonCorrelation()
    y_pred = torch.tensor([1.0])
    y = torch.tensor([1.0])
    m.update((y_pred, y))

    np_y_pred = y_pred.numpy()
    np_y = y.numpy()
    np_res = np_corr_eps(np_y_pred, np_y)
    assert pytest.approx(np_res) == m.compute()

    # constant samples
    m.reset()
    y_pred = torch.ones(10).float()
    y = torch.zeros(10).float()
    m.update((y_pred, y))

    np_y_pred = y_pred.numpy()
    np_y = y.numpy()
    np_res = np_corr_eps(np_y_pred, np_y)
    assert pytest.approx(np_res) == m.compute()


def test_pearson_correlation():
    a = np.random.randn(4).astype(np.float32)
    b = np.random.randn(4).astype(np.float32)
    c = np.random.randn(4).astype(np.float32)
    d = np.random.randn(4).astype(np.float32)
    ground_truth = np.random.randn(4).astype(np.float32)

    m = PearsonCorrelation()

    m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
    np_ans = scipy_corr(a, ground_truth)
    assert m.compute() == pytest.approx(np_ans, rel=1e-4)

    m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
    np_ans = scipy_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
    assert m.compute() == pytest.approx(np_ans, rel=1e-4)

    m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
    np_ans = scipy_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
    assert m.compute() == pytest.approx(np_ans, rel=1e-4)

    m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
    np_ans = scipy_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
    assert m.compute() == pytest.approx(np_ans, rel=1e-4)


@pytest.fixture(params=list(range(2)))
def test_case(request):
    # correlated sample
    x = torch.randn(size=[50]).float()
    y = x + torch.randn_like(x) * 0.1

    return [
        (x, y, 1),
        (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_integration(n_times, test_case: Tuple[Tensor, Tensor, int]):
    y_pred, y, batch_size = test_case

    def update_fn(engine: Engine, batch):
        idx = (engine.state.iteration - 1) * batch_size
        y_true_batch = np_y[idx : idx + batch_size]
        y_pred_batch = np_y_pred[idx : idx + batch_size]
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    engine = Engine(update_fn)

    m = PearsonCorrelation()
    m.attach(engine, "corr")

    np_y = y.ravel().numpy()
    np_y_pred = y_pred.ravel().numpy()

    data = list(range(y_pred.shape[0] // batch_size))
    corr = engine.run(data, max_epochs=1).metrics["corr"]

    np_ans = scipy_corr(np_y_pred, np_y)

    assert pytest.approx(np_ans, rel=2e-4) == corr


def test_accumulator_detached():
    corr = PearsonCorrelation()

    y_pred = torch.tensor([2.0, 3.0], requires_grad=True)
    y = torch.tensor([-2.0, -1.0])
    corr.update((y_pred, y))

    assert all(
        (not accumulator.requires_grad)
        for accumulator in (
            corr._sum_of_products,
            corr._sum_of_y_pred_squares,
            corr._sum_of_y_preds,
            corr._sum_of_y_squares,
            corr._sum_of_ys,
        )
    )


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_compute(self):
        rank = idist.get_rank()
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        torch.manual_seed(10 + rank)
        for metric_device in metric_devices:
            m = PearsonCorrelation(device=metric_device)

            y_pred = torch.rand(size=[100], device=device)
            y = torch.rand(size=[100], device=device)

            m.update((y_pred, y))

            y_pred = idist.all_gather(y_pred)
            y = idist.all_gather(y)

            np_y = y.cpu().numpy()
            np_y_pred = y_pred.cpu().numpy()

            np_ans = scipy_corr(np_y_pred, np_y)

            assert pytest.approx(np_ans) == m.compute()

    @pytest.mark.parametrize("n_epochs", [1, 2])
    def test_integration(self, n_epochs: int):
        tol = 2e-4
        rank = idist.get_rank()
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        n_iters = 80
        batch_size = 16

        for metric_device in metric_devices:
            torch.manual_seed(12 + rank)

            y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
            y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)

            engine = Engine(
                lambda e, i: (
                    y_preds[i * batch_size : (i + 1) * batch_size],
                    y_true[i * batch_size : (i + 1) * batch_size],
                )
            )

            corr = PearsonCorrelation(device=metric_device)
            corr.attach(engine, "corr")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=n_epochs)

            y_preds = idist.all_gather(y_preds)
            y_true = idist.all_gather(y_true)

            assert "corr" in engine.state.metrics

            res = engine.state.metrics["corr"]

            np_y = y_true.cpu().numpy()
            np_y_pred = y_preds.cpu().numpy()

            np_ans = scipy_corr(np_y_pred, np_y)

            assert pytest.approx(np_ans, rel=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            corr = PearsonCorrelation(device=metric_device)

            devices = (
                corr._device,
                corr._sum_of_products.device,
                corr._sum_of_y_pred_squares.device,
                corr._sum_of_y_preds.device,
                corr._sum_of_y_squares.device,
                corr._sum_of_ys.device,
            )
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            y_pred = torch.tensor([2.0, 3.0])
            y = torch.tensor([-1.0, 1.0])
            corr.update((y_pred, y))

            devices = (
                corr._device,
                corr._sum_of_products.device,
                corr._sum_of_y_pred_squares.device,
                corr._sum_of_y_preds.device,
                corr._sum_of_y_squares.device,
                corr._sum_of_ys.device,
            )
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
```

Note: the original diff built `np_y` from `y_pred.numpy()` in both halves of `test_degenerated_sample`, which compares the predictions against themselves; it is corrected to `y.numpy()` above (the assertions happen to pass either way, since both cases have a zero-variance input).
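`test_degenerated_sample` exercises the `eps` clamp: with a constant input, both standard deviations are zero, so the unclamped ratio would be the undefined 0/0, while clamping the denominator from below at `eps` yields an exact 0. A pure-Python sketch of that behaviour (illustrative only, not the metric's code):

```python
import math


def pearson_with_eps(preds, actuals, eps=1e-8):
    # Same formula as the metric: denominator clamped from below at eps.
    n = len(preds)
    mp = sum(preds) / n
    ma = sum(actuals) / n
    cov = sum((p - mp) * (a - ma) for p, a in zip(preds, actuals)) / n
    var_p = sum((p - mp) ** 2 for p in preds) / n
    var_a = sum((a - ma) ** 2 for a in actuals) / n
    return cov / max(math.sqrt(var_p * var_a), eps)


# Constant predictions vs. constant targets: cov = 0 and both stds are 0,
# so the clamped denominator turns 0/0 into an exact 0.0 instead of NaN.
r_const = pearson_with_eps([1.0] * 10, [0.0] * 10)

# A perfectly linear relationship still yields r close to 1.
r_linear = pearson_with_eps([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
```

This is also why `np_corr_eps` in the test file clips the product of standard deviations with `np.clip(..., eps, None)` rather than comparing against `scipy.stats.pearsonr`, which would return NaN for constant input.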