@@ -41,14 +41,14 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None:
4141 @reinit__is_reduced
4242 def reset (self ) -> None :
4343 if self ._average == "samples" :
44- self ._sum_samples_metric = torch . tensor ( 0 , device = self . _device ) # type: torch.Tensor
44+ self ._sum_samples_metric = 0 # type: Union[int, torch.Tensor]
4545 self ._samples_cnt = 0 # type: int
4646 else :
47- self ._true_positives = torch . tensor ( 0 , device = self . _device ) # type: torch.Tensor
48- self ._positives = torch . tensor ( 0 , device = self . _device ) # type: torch.Tensor
47+ self ._true_positives = 0 # type: Union[int, torch.Tensor]
48+ self ._positives = 0 # type: Union[int, torch.Tensor]
4949
5050 if self ._average == "weighted" :
51- self ._actual_positives = torch . tensor ( 0 , device = self . _device ) # type: torch.Tensor
51+ self ._actual_positives = 0 # type: Union[int, torch.Tensor]
5252 self ._updated = False
5353
5454 super (_BasePrecisionRecall , self ).reset ()
@@ -70,15 +70,16 @@ def compute(self) -> Union[torch.Tensor, float]:
7070 self ._is_reduced = True # type: bool
7171
7272 if self ._average == "samples" :
73- return (self ._sum_samples_metric / self ._samples_cnt ).item ()
73+ return (self ._sum_samples_metric / self ._samples_cnt ).item () # type: ignore
7474
7575 result = self ._true_positives / (self ._positives + self .eps )
7676 if self ._average == "weighted" :
77- return ((result @ self ._actual_positives ) / (self ._actual_positives .sum () + self .eps )).item ()
77+ denominator = self ._actual_positives .sum () + self .eps # type: ignore
78+ return ((result @ self ._actual_positives ) / denominator ).item () # type: ignore
7879 elif self ._average == "micro" :
79- return result .item ()
80+ return result .item () # type: ignore
8081 else :
81- return result if self ._type != "binary" else result .item ()
82+ return result if self ._type != "binary" else result .item () # type: ignore
8283
8384
8485class Precision (_BasePrecisionRecall ):
0 commit comments