diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index b0d31518e1c4..330311c8a78e 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -219,7 +219,7 @@ def __init__( @abstractmethod def reset(self) -> None: """ - Resets the metric to it's initial state. + Resets the metric to its initial state. By default, this is called at the start of each epoch. """ @@ -240,7 +240,7 @@ def update(self, output: Any) -> None: @abstractmethod def compute(self) -> Any: """ - Computes the metric based on it's accumulated state. + Computes the metric based on its accumulated state. By default, this is called at the end of each epoch. @@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None: Note: ``engine.state.output`` is used to compute metric values. - The majority of implemented metrics accepts the following formats for ``engine.state.output``: + The majority of implemented metrics accept the following formats for ``engine.state.output``: ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or list of tensors/numbers if applicable. diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index dbbfa1ba1a05..111b77962b9e 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens num_classes = 2 if self._type == "binary" else y_pred.size(1) if self._type == "multiclass" and y.max() + 1 > num_classes: raise ValueError( - f"y_pred contains less classes than y. Number of predicted classes is {num_classes}" - f" and element in y has invalid class = {y.max().item() + 1}." + f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}" + f" and an element in y has invalid class = {y.max().item() + 1}." ) y = y.view(-1) if self._type == "binary" and self._average is False: @@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens @reinit__is_reduced def reset(self) -> None: - # `numerator`, `denominator` and `weight` are three variables chosen to be abstract - # representatives of the ones that are measured for cases with different `average` parameters. - # `weight` is only used when `average='weighted'`. Actual value of these three variables is - # as follows. - # - # average='samples': - # numerator (torch.Tensor): sum of metric value for samples - # denominator (int): number of samples - # - # average='weighted': - # numerator (torch.Tensor): number of true positives per class/label - # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) - # positives per class/label - # weight (torch.Tensor): number of actual positives per class - # - # average='micro': - # numerator (torch.Tensor): sum of number of true positives for classes/labels - # denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives - # for classes/labels - # - # average='macro' or boolean or None: - # numerator (torch.Tensor): number of true positives per class/label - # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) - # positives per class/label + """ + `numerator`, `denominator` and `weight` are three variables chosen to be abstract + representatives of the ones that are measured for cases with different `average` parameters. + `weight` is only used when `average='weighted'`. Actual value of these three variables is + as follows. + + average='samples': + numerator (torch.Tensor): sum of metric value for samples + denominator (int): number of samples + + average='weighted': + numerator (torch.Tensor): number of true positives per class/label + denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per + class/label. + weight (torch.Tensor): number of actual positives per class + + average='micro': + numerator (torch.Tensor): sum of number of true positives for classes/labels + denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for + classes/labels. + + average='macro' or boolean or None: + numerator (torch.Tensor): number of true positives per class/label + denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per + class/label. + """ self._numerator: Union[int, torch.Tensor] = 0 self._denominator: Union[int, torch.Tensor] = 0 @@ -120,16 +122,20 @@ def reset(self) -> None: @sync_all_reduce("_numerator", "_denominator") def compute(self) -> Union[torch.Tensor, float]: - # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows. - # - # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight - # - # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C` - # for the `macro` one. :math:`C` is the number of classes/labels. - # - # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows. - # - # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } + r""" + Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows. + + .. math:: + \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight + + wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C` + for the `macro` one. :math:`C` is the number of classes/labels. + + Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows. + + .. math:: + \text{Precision/Recall} = \frac{ numerator }{ denominator } + """ if not self._updated: raise NotComputableError( @@ -367,6 +373,33 @@ def thresholded_output_transform(output): @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: + r""" + Update the metric state using prediction and target. + + Args: + output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch + dimension, `...` for possible additional dimensions and C for class dimension. + + .. list-table:: + :widths: 20 10 10 10 + :header-rows: 1 + + * - Output member\\Data type + - Binary + - Multiclass + - Multilabel + * - y_pred + - (N, ...) + - (N, C, ...) + - (N, C, ...) + * - y + - (N, ...) + - (N, ...) + - (N, C, ...) + + For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass + data, y_pred and y should consist of probabilities and integers respectively. + """ self._check_shape(output) self._check_type(output) y_pred, y, correct = self._prepare_output(output)