Skip to content

Commit aa3fc9b

Browse files
Fix flake8 issue and some bugs
1 parent 9f9d323 commit aa3fc9b

3 files changed

Lines changed: 17 additions & 12 deletions

File tree

ignite/metrics/precision.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None:
4141
@reinit__is_reduced
4242
def reset(self) -> None:
4343
if self._average == "samples":
44-
self._sum_samples_metric = torch.tensor(0, device=self._device) # type: torch.Tensor
44+
self._sum_samples_metric = 0 # type: Union[int, torch.Tensor]
4545
self._samples_cnt = 0 # type: int
4646
else:
47-
self._true_positives = torch.tensor(0, device=self._device) # type: torch.Tensor
48-
self._positives = torch.tensor(0, device=self._device) # type: torch.Tensor
47+
self._true_positives = 0 # type: Union[int, torch.Tensor]
48+
self._positives = 0 # type: Union[int, torch.Tensor]
4949

5050
if self._average == "weighted":
51-
self._actual_positives = torch.tensor(0, device=self._device) # type: torch.Tensor
51+
self._actual_positives = 0 # type: Union[int, torch.Tensor]
5252
self._updated = False
5353

5454
super(_BasePrecisionRecall, self).reset()
@@ -70,15 +70,16 @@ def compute(self) -> Union[torch.Tensor, float]:
7070
self._is_reduced = True # type: bool
7171

7272
if self._average == "samples":
73-
return (self._sum_samples_metric / self._samples_cnt).item()
73+
return (self._sum_samples_metric / self._samples_cnt).item() # type: ignore
7474

7575
result = self._true_positives / (self._positives + self.eps)
7676
if self._average == "weighted":
77-
return ((result @ self._actual_positives) / (self._actual_positives.sum() + self.eps)).item()
77+
denominator = self._actual_positives.sum() + self.eps # type: ignore
78+
return ((result @ self._actual_positives) / denominator).item() # type: ignore
7879
elif self._average == "micro":
79-
return result.item()
80+
return result.item() # type: ignore
8081
else:
81-
return result if self._type != "binary" else result.item()
82+
return result if self._type != "binary" else result.item() # type: ignore
8283

8384

8485
class Precision(_BasePrecisionRecall):

tests/ignite/metrics/test_precision.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,8 @@ def _test(average, metric_device):
539539
if average == "weighted":
540540
assert (
541541
pr._actual_positives.device == metric_device
542-
), f"{type(pr._actual_positives.device)}:{pr._actual_positives.device} vs {type(metric_device)}:{metric_device}"
542+
), f"{type(pr._actual_positives.device)}:{pr._actual_positives.device} vs "
543+
f"{type(metric_device)}:{metric_device}"
543544

544545
metric_devices = [torch.device("cpu")]
545546
if device.type != "xla":
@@ -566,7 +567,8 @@ def _test(average, metric_device):
566567
if average == "samples":
567568
assert (
568569
pr._sum_samples_metric.device == metric_device
569-
), f"{type(pr._sum_samples_metric.device)}:{pr._sum_samples_metric.device} vs {type(metric_device)}:{metric_device}"
570+
), f"{type(pr._sum_samples_metric.device)}:{pr._sum_samples_metric.device} vs "
571+
f"{type(metric_device)}:{metric_device}"
570572
else:
571573
assert (
572574
pr._true_positives.device == metric_device
@@ -577,7 +579,8 @@ def _test(average, metric_device):
577579
if average == "weighted":
578580
assert (
579581
pr._actual_positives.device == metric_device
580-
), f"{type(pr._actual_positives.device)}:{pr._actual_positives.device} vs {type(metric_device)}:{metric_device}"
582+
), f"{type(pr._actual_positives.device)}:{pr._actual_positives.device} vs "
583+
f"{type(metric_device)}:{metric_device}"
581584

582585
metric_devices = [torch.device("cpu")]
583586
if device.type != "xla":

tests/ignite/metrics/test_recall.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,8 @@ def _test(average, metric_device):
553553
if average == "samples":
554554
assert (
555555
re._sum_samples_metric.device == metric_device
556-
), f"{type(re._sum_samples_metric.device)}:{re._sum_samples_metric.device} vs {type(metric_device)}:{metric_device}"
556+
), f"{type(re._sum_samples_metric.device)}:{re._sum_samples_metric.device} vs "
557+
f"{type(metric_device)}:{metric_device}"
557558
else:
558559
assert (
559560
re._true_positives.device == metric_device

0 commit comments

Comments
 (0)