Skip to content

Commit f9e4c8c

Browse files
authored
available device in test_object_detection_map.py (#3409)
1 parent 65e5085 commit f9e4c8c

1 file changed

Lines changed: 34 additions & 27 deletions

File tree

tests/ignite/metrics/vision/test_object_detection_map.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -623,12 +623,13 @@ def test_wrong_input():
623623
m.update(([{"bbox": None, "scores": None, "labels": None}], [{"labels": None}]))
624624

625625

626-
def test_empty_data():
626+
def test_empty_data(available_device):
627627
"""
628628
Note that PyCOCO returns -1 when threre's no ground truth data.
629629
"""
630630

631-
metric = ObjectDetectionAvgPrecisionRecall()
631+
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
632+
assert metric._device == torch.device(available_device)
632633
metric.update(
633634
(
634635
[{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}],
@@ -652,7 +653,8 @@ def test_empty_data():
652653
)
653654
assert metric.compute() == (-1, -1)
654655

655-
metric = ObjectDetectionAvgPrecisionRecall()
656+
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
657+
assert metric._device == torch.device(available_device)
656658
metric.update(
657659
(
658660
[{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}],
@@ -670,7 +672,8 @@ def test_empty_data():
670672
assert metric._y_true_count[1] == 1
671673
assert metric.compute() == (0, 0)
672674

673-
metric = ObjectDetectionAvgPrecisionRecall()
675+
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
676+
assert metric._device == torch.device(available_device)
674677
pred = {
675678
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
676679
"scores": torch.tensor([0.9]),
@@ -689,8 +692,9 @@ def test_no_torchvision():
689692
ObjectDetectionAvgPrecisionRecall()
690693

691694

692-
def test_iou(sample):
693-
m = ObjectDetectionAvgPrecisionRecall(num_classes=91)
695+
def test_iou(sample, available_device):
696+
m = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=available_device)
697+
assert m._device == torch.device(available_device)
694698
from pycocotools.mask import iou as pycoco_iou
695699

696700
for pred, tgt in zip(*sample.data):
@@ -710,8 +714,9 @@ def test_iou(sample):
710714
assert equal.all()
711715

712716

713-
def test_iou_thresholding():
714-
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75])
717+
def test_iou_thresholding(available_device):
718+
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75], device=available_device)
719+
assert metric._device == torch.device(available_device)
715720

716721
pred = {
717722
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
@@ -720,7 +725,7 @@ def test_iou_thresholding():
720725
}
721726
gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])}
722727
metric.update(([pred], [gt]))
723-
assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]])).all()
728+
assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]], device=available_device)).all()
724729

725730
pred = {
726731
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
@@ -729,10 +734,10 @@ def test_iou_thresholding():
729734
}
730735
gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])}
731736
metric.update(([pred], [gt]))
732-
assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]])).all()
737+
assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]], device=available_device)).all()
733738

734739

735-
def test_matching():
740+
def test_matching(available_device):
736741
"""
737742
PyCOCO matching rules:
738743
1. The higher confidence in a prediction, the sooner decision is made for it.
@@ -750,7 +755,8 @@ def test_matching():
750755
7. Non-ignored ground truths are given priority over the ignored ones when matching with a prediction
751756
even if their IOU is lower.
752757
"""
753-
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2])
758+
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2], device=available_device)
759+
assert metric._device == torch.device(available_device)
754760

755761
pred = {
756762
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]),
@@ -764,27 +770,27 @@ def test_matching():
764770
}
765771
metric.update(([pred], [gt]))
766772
# Preds are sorted by their scores internally
767-
assert (metric._tps[0] == torch.tensor([[True, False]])).all()
768-
assert (metric._fps[0] == torch.tensor([[False, True]])).all()
769-
assert (metric._scores[0] == torch.tensor([[0.9, 0.8]])).all()
773+
assert (metric._tps[0] == torch.tensor([[True, False]], device=available_device)).all()
774+
assert (metric._fps[0] == torch.tensor([[False, True]], device=available_device)).all()
775+
assert (metric._scores[0] == torch.tensor([[0.9, 0.8]], device=available_device)).all()
770776

771777
pred["scores"] = torch.tensor([0.9, 0.9])
772778
metric.update(([pred], [gt]))
773-
assert (metric._tps[1] == torch.tensor([[True, False]])).all()
774-
assert (metric._fps[1] == torch.tensor([[False, True]])).all()
779+
assert (metric._tps[1] == torch.tensor([[True, False]], device=available_device)).all()
780+
assert (metric._fps[1] == torch.tensor([[False, True]], device=available_device)).all()
775781

776782
gt["iscrowd"] = torch.tensor([1])
777783
metric.update(([pred], [gt]))
778-
assert (metric._tps[2] == torch.tensor([[False, False]])).all()
779-
assert (metric._fps[2] == torch.tensor([[False, False]])).all()
784+
assert (metric._tps[2] == torch.tensor([[False, False]], device=available_device)).all()
785+
assert (metric._fps[2] == torch.tensor([[False, False]], device=available_device)).all()
780786

781787
pred["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]])
782788
gt["bbox"] = torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]])
783789
gt["iscrowd"] = torch.zeros((2,))
784790
gt["labels"] = torch.tensor([1, 1])
785791
metric.update(([pred], [gt]))
786-
assert (metric._tps[3] == torch.tensor([[True, False]])).all()
787-
assert (metric._fps[3] == torch.tensor([[False, True]])).all()
792+
assert (metric._tps[3] == torch.tensor([[True, False]], device=available_device)).all()
793+
assert (metric._fps[3] == torch.tensor([[False, True]], device=available_device)).all()
788794

789795
metric._area_range = "small"
790796
pred["bbox"] = torch.tensor(
@@ -794,14 +800,14 @@ def test_matching():
794800
pred["labels"] = torch.tensor([1, 1, 1, 1])
795801
gt["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 5.0]])
796802
metric.update(([pred], [gt]))
797-
assert (metric._tps[4] == torch.tensor([[True, False, False, False]])).all()
798-
assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all()
803+
assert (metric._tps[4] == torch.tensor([[True, False, False, False]], device=available_device)).all()
804+
assert (metric._fps[4] == torch.tensor([[False, False, False, True]], device=available_device)).all()
799805

800806
pred["scores"] = torch.tensor([0.9, 1.0, 0.9, 0.9])
801807
metric._max_detections_per_image_per_class = 1
802808
metric.update(([pred], [gt]))
803-
assert (metric._tps[5] == torch.tensor([[True]])).all()
804-
assert (metric._fps[5] == torch.tensor([[False]])).all()
809+
assert (metric._tps[5] == torch.tensor([[True]], device=available_device)).all()
810+
assert (metric._fps[5] == torch.tensor([[False]], device=available_device)).all()
805811

806812

807813
def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score):
@@ -825,11 +831,12 @@ def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold
825831
return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None
826832

827833

828-
def test__compute_recall_and_precision():
834+
def test__compute_recall_and_precision(available_device):
829835
# The case in which detector detects all gt objects but also produces some wrong predictions.
830836
scores = torch.rand((50,))
831837
y_true = torch.randint(0, 2, (50,))
832-
m = ObjectDetectionAvgPrecisionRecall()
838+
m = ObjectDetectionAvgPrecisionRecall(device=available_device)
839+
assert m._device == torch.device(available_device)
833840

834841
ignite_recall, ignite_precision = m._compute_recall_and_precision(
835842
y_true.bool(), ~(y_true.bool()), scores, y_true.sum()

0 commit comments

Comments
 (0)