Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ Complete list of metrics
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metric_group.MetricGroup
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
MutualInformation
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metric_group import MetricGroup
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
from ignite.metrics.mutual_information import MutualInformation
Expand All @@ -41,6 +42,7 @@
"Metric",
"Accuracy",
"Loss",
"MetricGroup",
"MetricsLambda",
"MeanAbsoluteError",
"MeanPairwiseDistance",
Expand Down
51 changes: 51 additions & 0 deletions ignite/metrics/metric_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any, Callable, Dict, Sequence

import torch

from ignite.metrics import Metric


class MetricGroup(Metric):
"""
A class for grouping metrics so that user could manage them easier.

Args:
metrics: a dictionary of names to metric instances.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. `output_transform` of each metric in the group is also
called upon its update.

Examples:
We construct a group of metrics, attach them to the engine at once and retrieve their result.

.. code-block:: python
metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())})
metric_group.attach(default_evaluator, "eval_metrics")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])

# Metrics individually available in `state.metrics`
state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]

# And also altogether
state.metrics["eval_metrics]
"""

_state_dict_all_req_keys = ("metrics",)

def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
self.metrics = metrics
super(MetricGroup, self).__init__(output_transform=output_transform)

def reset(self):
for m in self.metrics.values():
m.reset()

def update(self, output: Sequence[torch.Tensor]):
for m in self.metrics.values():
m.update(m._output_transform(output))

def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.metrics.items()}
118 changes: 118 additions & 0 deletions tests/ignite/metrics/test_metric_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest
import torch

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy, MetricGroup, Precision

torch.manual_seed(41)


def test_update():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_output_transform():
def drop_first(output):
y_pred, y = output
return (y_pred[1:], y[1:])

precision = Precision(output_transform=drop_first)
accuracy = Accuracy(output_transform=drop_first)

group = MetricGroup(
{"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
)

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update(drop_first(drop_first((y_pred, y))))
accuracy.update(drop_first(drop_first((y_pred, y))))
group.update(drop_first((y_pred, y)))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_compute():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

for _ in range(3):
y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}

precision.reset()
accuracy.reset()
group.reset()

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_integration(self):
rank = idist.get_rank()
torch.manual_seed(12 + rank)

n_epochs = 3
n_iters = 5
batch_size = 10
device = idist.device()

y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)

def update(_, i):
return (
y_pred[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

precision = Precision()
precision.attach(engine, "precision")

accuracy = Accuracy()
accuracy.attach(engine, "accuracy")

group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
group.attach(engine, "eval_metrics")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

assert "eval_metrics" in engine.state.metrics
assert "eval_metrics.accuracy" in engine.state.metrics
assert "eval_metrics.precision" in engine.state.metrics

assert engine.state.metrics["eval_metrics"] == {
"eval_metrics.accuracy": engine.state.metrics["accuracy"],
"eval_metrics.precision": engine.state.metrics["precision"],
}
assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]