Skip to content

Commit d450903

Browse files
Do some refactors and improvements
1 parent 214088b commit d450903

File tree

4 files changed

+20
-28
lines changed

4 files changed

+20
-28
lines changed

ignite/metrics/precision.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class Precision(_BasePrecisionRecall):
188188
weighted_metric = Precision(average='weighted')
189189
metric.attach(default_evaluator, "precision")
190190
weighted_metric.attach(default_evaluator, "weighted precision")
191-
y_true = torch.Tensor([1, 0, 1, 1, 0, 1]).long()
191+
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
192192
y_pred = torch.Tensor([1, 0, 1, 0, 1, 1])
193193
state = default_evaluator.run([[y_pred, y_true]])
194194
print(f"Precision: {state.metrics['precision']}")
@@ -211,7 +211,7 @@ class Precision(_BasePrecisionRecall):
211211
macro_metric.attach(default_evaluator, "macro precision")
212212
weighted_metric.attach(default_evaluator, "weighted precision")
213213
214-
y_true = torch.Tensor([2, 0, 2, 1, 0]).long()
214+
y_true = torch.tensor([2, 0, 2, 1, 0])
215215
y_pred = torch.Tensor([
216216
[0.0266, 0.1719, 0.3055],
217217
[0.6886, 0.3978, 0.8176],
@@ -286,14 +286,14 @@ def thresholded_output_transform(output):
286286
287287
metric = Precision(output_transform=thresholded_output_transform)
288288
metric.attach(default_evaluator, "precision")
289-
y_true = torch.Tensor([1, 0, 1, 1, 0, 1]).long()
289+
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
290290
y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
291291
state = default_evaluator.run([[y_pred, y_true]])
292292
print(state.metrics["precision"])
293293
294294
.. testoutput:: 4
295295
296-
[0.5, 0.75]
296+
tensor([0.5000, 0.7500], dtype=torch.float64)
297297
298298
299299
.. versionchanged:: 0.5.0
@@ -306,23 +306,19 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
306306
self._check_type(output)
307307
y_pred, y = output[0].detach(), output[1].detach()
308308

309-
if self._type == "binary":
309+
if self._type == "binary" or self._type == "multiclass":
310310

311-
y = to_onehot(y.view(-1), num_classes=2)
312-
y_pred = to_onehot(y_pred.view(-1).long(), num_classes=2)
313-
elif self._type == "multiclass":
314-
315-
num_classes = y_pred.size(1)
316-
if y.max() + 1 > num_classes:
311+
num_classes = 2 if self._type == "binary" else y_pred.size(1)
312+
if self._type == "multiclass" and y.max() + 1 > num_classes:
317313
raise ValueError(
318314
f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
319315
f" and element in y has invalid class = {y.max().item() + 1}."
320316
)
321317
y = to_onehot(y.view(-1), num_classes=num_classes)
322-
indices = torch.argmax(y_pred, dim=1).view(-1)
323-
y_pred = to_onehot(indices, num_classes=num_classes)
318+
indices = torch.argmax(y_pred, dim=1) if self._type == "multiclass" else y_pred.long()
319+
y_pred = to_onehot(indices.view(-1), num_classes=num_classes)
324320
elif self._type == "multilabel":
325-
321+
# if y, y_pred shape is (N, C, ...) -> (N * ..., C)
326322
num_labels = y_pred.size(1)
327323
y_pred = torch.transpose(y_pred, 1, -1).reshape(-1, num_labels)
328324
y = torch.transpose(y, 1, -1).reshape(-1, num_labels)

ignite/metrics/recall.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,23 +216,19 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
216216
self._check_type(output)
217217
y_pred, y = output[0].detach(), output[1].detach()
218218

219-
if self._type == "binary":
219+
if self._type == "binary" or self._type == "multiclass":
220220

221-
y = to_onehot(y.view(-1), num_classes=2)
222-
y_pred = to_onehot(y_pred.view(-1).long(), num_classes=2)
223-
elif self._type == "multiclass":
224-
225-
num_classes = y_pred.size(1)
226-
if y.max() + 1 > num_classes:
221+
num_classes = 2 if self._type == "binary" else y_pred.size(1)
222+
if self._type == "multiclass" and y.max() + 1 > num_classes:
227223
raise ValueError(
228224
f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
229225
f" and element in y has invalid class = {y.max().item() + 1}."
230226
)
231227
y = to_onehot(y.view(-1), num_classes=num_classes)
232-
indices = torch.argmax(y_pred, dim=1).view(-1)
233-
y_pred = to_onehot(indices, num_classes=num_classes)
228+
indices = torch.argmax(y_pred, dim=1) if self._type == "multiclass" else y_pred.long()
229+
y_pred = to_onehot(indices.view(-1), num_classes=num_classes)
234230
elif self._type == "multilabel":
235-
231+
# if y, y_pred shape is (N, C, ...) -> (N * ..., C)
236232
num_labels = y_pred.size(1)
237233
y_pred = torch.transpose(y_pred, 1, -1).reshape(-1, num_labels)
238234
y = torch.transpose(y, 1, -1).reshape(-1, num_labels)

tests/ignite/metrics/test_precision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def test_multiclass_wrong_inputs():
169169
pr.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))
170170
assert pr._updated is False
171171

172-
pr = Precision()
172+
pr = Precision(average=True)
173173
assert pr._updated is False
174174

175175
with pytest.raises(ValueError):
@@ -184,7 +184,7 @@ def test_multiclass_wrong_inputs():
184184
pr.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
185185
assert pr._updated is True
186186

187-
pr = Precision()
187+
pr = Precision(average=False)
188188
assert pr._updated is False
189189

190190
with pytest.raises(ValueError):

tests/ignite/metrics/test_recall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_multiclass_wrong_inputs():
171171
re.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))
172172
assert re._updated is False
173173

174-
re = Recall()
174+
re = Recall(average=True)
175175
assert re._updated is False
176176

177177
with pytest.raises(ValueError):
@@ -186,7 +186,7 @@ def test_multiclass_wrong_inputs():
186186
re.update((torch.rand(10, 6, 12, 14), torch.randint(0, 5, size=(10, 12, 14)).long()))
187187
assert re._updated is True
188188

189-
re = Recall()
189+
re = Recall(average=False)
190190
assert re._updated is False
191191

192192
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)