@@ -188,7 +188,7 @@ class Precision(_BasePrecisionRecall):
188188 weighted_metric = Precision(average='weighted')
189189 metric.attach(default_evaluator, "precision")
190190 weighted_metric.attach(default_evaluator, "weighted precision")
191- y_true = torch.Tensor ([1, 0, 1, 1, 0, 1]).long( )
191+ y_true = torch.tensor ([1, 0, 1, 1, 0, 1])
192192 y_pred = torch.Tensor([1, 0, 1, 0, 1, 1])
193193 state = default_evaluator.run([[y_pred, y_true]])
194194 print(f"Precision: {state.metrics['precision']}")
@@ -211,7 +211,7 @@ class Precision(_BasePrecisionRecall):
211211 macro_metric.attach(default_evaluator, "macro precision")
212212 weighted_metric.attach(default_evaluator, "weighted precision")
213213
214- y_true = torch.Tensor ([2, 0, 2, 1, 0]).long( )
214+ y_true = torch.tensor ([2, 0, 2, 1, 0])
215215 y_pred = torch.Tensor([
216216 [0.0266, 0.1719, 0.3055],
217217 [0.6886, 0.3978, 0.8176],
@@ -286,14 +286,14 @@ def thresholded_output_transform(output):
286286
287287 metric = Precision(output_transform=thresholded_output_transform)
288288 metric.attach(default_evaluator, "precision")
289- y_true = torch.Tensor ([1, 0, 1, 1, 0, 1]).long( )
289+ y_true = torch.tensor ([1, 0, 1, 1, 0, 1])
290290 y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
291291 state = default_evaluator.run([[y_pred, y_true]])
292292 print(state.metrics["precision"])
293293
294294 .. testoutput:: 4
295295
296- [0.5 , 0.75]
296+ tensor( [0.5000 , 0.7500], dtype=torch.float64)
297297
298298
299299 .. versionchanged:: 0.5.0
@@ -306,23 +306,19 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
306306 self ._check_type (output )
307307 y_pred , y = output [0 ].detach (), output [1 ].detach ()
308308
309- if self ._type == "binary" :
309+ if self ._type == "binary" or self . _type == "multiclass" :
310310
311- y = to_onehot (y .view (- 1 ), num_classes = 2 )
312- y_pred = to_onehot (y_pred .view (- 1 ).long (), num_classes = 2 )
313- elif self ._type == "multiclass" :
314-
315- num_classes = y_pred .size (1 )
316- if y .max () + 1 > num_classes :
311+ num_classes = 2 if self ._type == "binary" else y_pred .size (1 )
312+ if self ._type == "multiclass" and y .max () + 1 > num_classes :
317313 raise ValueError (
318314 f"y_pred contains less classes than y. Number of predicted classes is { num_classes } "
319315 f" and element in y has invalid class = { y .max ().item () + 1 } ."
320316 )
321317 y = to_onehot (y .view (- 1 ), num_classes = num_classes )
322- indices = torch .argmax (y_pred , dim = 1 ). view ( - 1 )
323- y_pred = to_onehot (indices , num_classes = num_classes )
318+ indices = torch .argmax (y_pred , dim = 1 ) if self . _type == "multiclass" else y_pred . long ( )
319+ y_pred = to_onehot (indices . view ( - 1 ) , num_classes = num_classes )
324320 elif self ._type == "multilabel" :
325-
321+ # if y, y_pred shape is (N, C, ...) -> (N * ..., C)
326322 num_labels = y_pred .size (1 )
327323 y_pred = torch .transpose (y_pred , 1 , - 1 ).reshape (- 1 , num_labels )
328324 y = torch .transpose (y , 1 , - 1 ).reshape (- 1 , num_labels )
0 commit comments