From b6641dcc36f5600011b12af55f24b15d603f4ca8 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 17 May 2023 21:20:38 +0330 Subject: [PATCH 1/4] Apply improvements --- docs/Makefile | 4 +- ignite/metrics/metric.py | 6 +-- ignite/metrics/precision.py | 105 +++++++++++++++++++++++------------- 3 files changed, 74 insertions(+), 41 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index eedf03332d70..a7abc304dc9a 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,11 +20,11 @@ docset: html convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png rebuild: - rm -rf source/generated && make clean && make html + rm -rf source/generated && find ../ignite -name *.pyc -delete && rm -rf source/__pycache__ && make clean && make html .PHONY: help Makefile docset # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @ find ../ignite -name *.pyc -delete && $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index b0d31518e1c4..330311c8a78e 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -219,7 +219,7 @@ def __init__( @abstractmethod def reset(self) -> None: """ - Resets the metric to it's initial state. + Resets the metric to its initial state. By default, this is called at the start of each epoch. """ @@ -240,7 +240,7 @@ def update(self, output: Any) -> None: @abstractmethod def compute(self) -> Any: """ - Computes the metric based on it's accumulated state. + Computes the metric based on its accumulated state. By default, this is called at the end of each epoch. @@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None: Note: ``engine.state.output`` is used to compute metric values. - The majority of implemented metrics accepts the following formats for ``engine.state.output``: + The majority of implemented metrics accept the following formats for ``engine.state.output``: ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or list of tensors/numbers if applicable. diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index dbbfa1ba1a05..ab32152cf28b 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens num_classes = 2 if self._type == "binary" else y_pred.size(1) if self._type == "multiclass" and y.max() + 1 > num_classes: raise ValueError( - f"y_pred contains less classes than y. Number of predicted classes is {num_classes}" - f" and element in y has invalid class = {y.max().item() + 1}." + f"y_pred contains fewer classes than y. Number of classes in prediction is {num_classes}" + f" and an element in y has invalid class = {y.max().item() + 1}." ) y = y.view(-1) if self._type == "binary" and self._average is False: @@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens @reinit__is_reduced def reset(self) -> None: - # `numerator`, `denominator` and `weight` are three variables chosen to be abstract - # representatives of the ones that are measured for cases with different `average` parameters. - # `weight` is only used when `average='weighted'`. Actual value of these three variables is - # as follows. - # - # average='samples': - # numerator (torch.Tensor): sum of metric value for samples - # denominator (int): number of samples - # - # average='weighted': - # numerator (torch.Tensor): number of true positives per class/label - # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) - # positives per class/label - # weight (torch.Tensor): number of actual positives per class - # - # average='micro': - # numerator (torch.Tensor): sum of number of true positives for classes/labels - # denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives - # for classes/labels - # - # average='macro' or boolean or None: - # numerator (torch.Tensor): number of true positives per class/label - # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) - # positives per class/label + """ + `numerator`, `denominator` and `weight` are three variables chosen to be abstract + representatives of the ones that are measured for cases with different `average` parameters. + `weight` is only used when `average='weighted'`. Actual value of these three variables is + as follows. + + average='samples': + numerator (torch.Tensor): sum of metric value for samples + denominator (int): number of samples + + average='weighted': + numerator (torch.Tensor): number of true positives per class/label + denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per + class/label. + weight (torch.Tensor): number of actual positives per class + + average='micro': + numerator (torch.Tensor): sum of number of true positives for classes/labels + denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for + classes/labels. + + average='macro' or boolean or None: + numerator (torch.Tensor): number of true positives per class/label + denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per + class/label. + """ self._numerator: Union[int, torch.Tensor] = 0 self._denominator: Union[int, torch.Tensor] = 0 @@ -120,16 +122,20 @@ def reset(self) -> None: @sync_all_reduce("_numerator", "_denominator") def compute(self) -> Union[torch.Tensor, float]: - # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows. - # - # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight - # - # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C` - # for the `macro` one. :math:`C` is the number of classes/labels. - # - # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows. - # - # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } + r""" + Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows. + + .. math:: + \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight + + wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C` + for the `macro` one. :math:`C` is the number of classes/labels. + + Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows. + + .. math:: + \text{Precision/Recall} = \frac{ numerator }{ denominator } + """ if not self._updated: raise NotComputableError( @@ -367,6 +373,33 @@ def thresholded_output_transform(output): @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: + r""" + Update the metric state using prediction and target. + + Args: + output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch + dimension, `...` for possible additional dimensions and C for class dimension. + + .. list-table:: + :widths: 20 10 10 10 + :header-rows: 1 + + * - Output member\\Data type + - Binary + - Multiclass + - Multilabel + * - y_pred + - (N, ...) + - (N, C, ...) + - (N, C, ...) + * - y + - (N, ...) + - (N, ...) + - (N, C, ...) + + For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass + data, y_pred and y should consist of probabilities and integers respectively. + """ self._check_shape(output) self._check_type(output) y_pred, y, correct = self._prepare_output(output) From 31e13075b088364f3a8ab6bd54b4325bea04a7b4 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 17 May 2023 22:59:05 +0330 Subject: [PATCH 2/4] Fix a typo --- ignite/metrics/precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index ab32152cf28b..111b77962b9e 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -61,7 +61,7 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens num_classes = 2 if self._type == "binary" else y_pred.size(1) if self._type == "multiclass" and y.max() + 1 > num_classes: raise ValueError( - f"y_pred contains fewer classes than y. Number of classes in prediction is {num_classes}" + f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}" f" and an element in y has invalid class = {y.max().item() + 1}." ) y = y.view(-1) From 58dc190e3bcb83bdf9ed82ab3fd99c7f84827eea Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Thu, 18 May 2023 19:11:47 +0330 Subject: [PATCH 3/4] Revert Makefile change --- docs/Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index a7abc304dc9a..fb2be177c141 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,11 +20,11 @@ docset: html convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png rebuild: - rm -rf source/generated && find ../ignite -name *.pyc -delete && rm -rf source/__pycache__ && make clean && make html + rm -rf source/generated && make clean && make html .PHONY: help Makefile docset # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @ find ../ignite -name *.pyc -delete && $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @ $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) From e1cf1a445cafdf6c22afc3936067084851a21090 Mon Sep 17 00:00:00 2001 From: vfdev Date: Sat, 20 May 2023 22:39:04 +0200 Subject: [PATCH 4/4] Update docs/Makefile --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index fb2be177c141..eedf03332d70 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -27,4 +27,4 @@ rebuild: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @ $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)