2 changes: 1 addition & 1 deletion docs/Makefile
@@ -27,4 +27,4 @@ rebuild:
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+	@ $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
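For context on the hunk above: Sphinx's "make mode" forwards any unknown target verbatim, so `make html` ends up running `sphinx-build -M html "<sourcedir>" "<builddir>"`. The leading `@` in the recipe only suppresses echoing of the command, and `make` accepts it with or without a following space, so the change is purely cosmetic.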
6 changes: 3 additions & 3 deletions ignite/metrics/metric.py
@@ -219,7 +219,7 @@ def __init__(
     @abstractmethod
     def reset(self) -> None:
         """
-        Resets the metric to it's initial state.
+        Resets the metric to its initial state.

         By default, this is called at the start of each epoch.
         """
@@ -240,7 +240,7 @@ def update(self, output: Any) -> None:
     @abstractmethod
     def compute(self) -> Any:
         """
-        Computes the metric based on it's accumulated state.
+        Computes the metric based on its accumulated state.

         By default, this is called at the end of each epoch.

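To make the reset/compute contract concrete, here is a minimal sketch of a custom metric following the usual `Metric` subclassing pattern (the class name and its simple mean-absolute-error bookkeeping are illustrative, not part of this diff):

```python
import torch

from ignite.metrics import Metric


class MeanAbsoluteError(Metric):
    def reset(self) -> None:
        # Called at the start of each epoch: drop all accumulated state.
        self._sum_of_errors = 0.0
        self._num_examples = 0

    def update(self, output) -> None:
        # Called once per iteration with engine.state.output.
        y_pred, y = output
        self._sum_of_errors += torch.abs(y_pred - y).sum().item()
        self._num_examples += y.numel()

    def compute(self) -> float:
        # Called at the end of each epoch, after all updates.
        return self._sum_of_errors / max(self._num_examples, 1)
```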
@@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None:

         Note:
             ``engine.state.output`` is used to compute metric values.
-            The majority of implemented metrics accepts the following formats for ``engine.state.output``:
+            The majority of implemented metrics accept the following formats for ``engine.state.output``:
             ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or
             list of tensors/numbers if applicable.

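For a trainer whose `engine.state.output` does not already match these formats, an `output_transform` can adapt it. A minimal sketch (the dict keys here are assumed, not prescribed by ignite):

```python
from ignite.metrics import Precision


def output_transform(output):
    # Suppose the process function returns {"loss": ..., "y_pred": ..., "y": ...};
    # keep only the (y_pred, y) pair that the metric consumes.
    return output["y_pred"], output["y"]


precision = Precision(output_transform=output_transform)
```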
105 changes: 69 additions & 36 deletions ignite/metrics/precision.py
@@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]:
         num_classes = 2 if self._type == "binary" else y_pred.size(1)
         if self._type == "multiclass" and y.max() + 1 > num_classes:
             raise ValueError(
-                f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
-                f" and element in y has invalid class = {y.max().item() + 1}."
+                f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}"
+                f" and an element in y has invalid class = {y.max().item() + 1}."
             )
         y = y.view(-1)
         if self._type == "binary" and self._average is False:
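A small illustration (values assumed) of when the reworded error fires:

```python
import torch

y_pred = torch.randn(4, 3)      # multiclass scores for 3 classes -> num_classes == 3
y = torch.tensor([0, 1, 2, 3])  # label 3 implies at least 4 classes
# y.max() + 1 == 4 > num_classes == 3, so an update with (y_pred, y) would raise:
# "y_pred contains fewer classes than y. Number of classes in the prediction is 3
#  and an element in y has invalid class = 4."
```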
@@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]:

     @reinit__is_reduced
     def reset(self) -> None:
-        # `numerator`, `denominator` and `weight` are three variables chosen to be abstract
-        # representatives of the ones that are measured for cases with different `average` parameters.
-        # `weight` is only used when `average='weighted'`. Actual value of these three variables is
-        # as follows.
-        #
-        # average='samples':
-        #     numerator (torch.Tensor): sum of metric value for samples
-        #     denominator (int): number of samples
-        #
-        # average='weighted':
-        #     numerator (torch.Tensor): number of true positives per class/label
-        #     denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #         positives per class/label
-        #     weight (torch.Tensor): number of actual positives per class
-        #
-        # average='micro':
-        #     numerator (torch.Tensor): sum of number of true positives for classes/labels
-        #     denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
-        #         for classes/labels
-        #
-        # average='macro' or boolean or None:
-        #     numerator (torch.Tensor): number of true positives per class/label
-        #     denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #         positives per class/label
+        """
+        `numerator`, `denominator` and `weight` are three variables chosen to be abstract
+        representatives of the ones that are measured for cases with different `average` parameters.
+        `weight` is only used when `average='weighted'`. The actual values of these three variables are
+        as follows.
+
+        average='samples':
+            numerator (torch.Tensor): sum of metric value for samples
+            denominator (int): number of samples
+
+        average='weighted':
+            numerator (torch.Tensor): number of true positives per class/label
+            denominator (torch.Tensor): number of predicted (for precision) or actual (for recall) positives per
+                class/label.
+            weight (torch.Tensor): number of actual positives per class
+
+        average='micro':
+            numerator (torch.Tensor): sum of number of true positives for classes/labels
+            denominator (torch.Tensor): sum of number of predicted (for precision) or actual (for recall) positives
+                for classes/labels.
+
+        average='macro' or boolean or None:
+            numerator (torch.Tensor): number of true positives per class/label
+            denominator (torch.Tensor): number of predicted (for precision) or actual (for recall) positives per
+                class/label.
+        """

         self._numerator: Union[int, torch.Tensor] = 0
         self._denominator: Union[int, torch.Tensor] = 0
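As a rough illustration of the bookkeeping described in the new docstring (all tensor values assumed), for a 2-class precision the state per `average` option would look like:

```python
import torch

# average='macro' (and boolean/None): per-class counts are kept separately.
numerator = torch.tensor([2.0, 1.0])    # true positives per class
denominator = torch.tensor([3.0, 2.0])  # predicted positives per class (precision)

# average='micro': the same counts are summed over classes before dividing.
micro_numerator = numerator.sum()      # tensor(3.)
micro_denominator = denominator.sum()  # tensor(5.)
```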
@@ -120,16 +122,20 @@ def reset(self) -> None:

     @sync_all_reduce("_numerator", "_denominator")
     def compute(self) -> Union[torch.Tensor, float]:
-        # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
-        #
-        # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
-        # for the `macro` one. :math:`C` is the number of classes/labels.
-        #
-        # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        r"""
+        The return value of the metric for the `average` options `'weighted'` and `'macro'` is computed as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
+
+        wherein `weight` is the internal variable `_weight` for the `'weighted'` option and :math:`1/C`
+        for the `'macro'` one. :math:`C` is the number of classes/labels.
+
+        The return value of the metric for the `average` options `'micro'`, `'samples'`, `False` and `None` is as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        """

         if not self._updated:
             raise NotComputableError(
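A worked sketch of the formulas above with assumed counts for C = 2 classes (the weight normalization here follows the usual support-weighted average; consult the implementation for the exact form):

```python
import torch

numerator = torch.tensor([2.0, 1.0])    # true positives per class
denominator = torch.tensor([3.0, 2.0])  # predicted positives per class
C = numerator.numel()

# average='macro': weight is 1/C, i.e. the unweighted mean of per-class precision.
macro = ((numerator / denominator) * (1.0 / C)).sum()  # tensor(0.5833)

# average='weighted': weight comes from the number of actual positives per class.
support = torch.tensor([4.0, 6.0])
weighted = ((numerator / denominator) * (support / support.sum())).sum()  # tensor(0.5667)
```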
@@ -367,6 +373,33 @@ def thresholded_output_transform(output):

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
+        r"""
+        Update the metric state using the given prediction and target.
+
+        Args:
+            output: a tuple of two tensors, (y_pred, y), whose shapes follow the table below. N stands for the
+                batch dimension, `...` for possible additional dimensions and C for the class dimension.
+
+                .. list-table::
+                    :widths: 20 10 10 10
+                    :header-rows: 1
+
+                    * - Output member\\Data type
+                      - Binary
+                      - Multiclass
+                      - Multilabel
+                    * - y_pred
+                      - (N, ...)
+                      - (N, C, ...)
+                      - (N, C, ...)
+                    * - y
+                      - (N, ...)
+                      - (N, ...)
+                      - (N, C, ...)
+
+                For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
+                data, y_pred and y should consist of probabilities and integers, respectively.
+        """
         self._check_shape(output)
         self._check_type(output)
         y_pred, y, correct = self._prepare_output(output)
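A short usage sketch matching the binary column of the table (values assumed):

```python
import torch

from ignite.metrics import Precision

precision = Precision()
precision.reset()

y_pred = torch.tensor([1, 0, 1, 1])  # (N,), already thresholded to 0's and 1's
y = torch.tensor([1, 0, 0, 1])       # (N,)
precision.update((y_pred, y))

# 2 true positives out of 3 predicted positives -> 2/3
print(precision.compute())
```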