61 changes: 34 additions & 27 deletions tests/ignite/metrics/vision/test_object_detection_map.py
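Every test in this diff gains an available_device fixture argument and forwards it to the metric via device=available_device. The fixture itself is not part of this file; it presumably lives in the suite's conftest.py. A minimal sketch of such a fixture, assuming it simply enumerates the device strings usable on the current machine (the helper name _available_devices is an assumption, not taken from the PR):

import pytest
import torch


def _available_devices():
    # Always exercise CPU; add an accelerator only when its backend is actually usable.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return devices


@pytest.fixture(params=_available_devices())
def available_device(request):
    # Each test requesting this fixture runs once per available device string.
    return request.param

With a parametrized fixture like this, a test such as test_empty_data(available_device) runs once per device, and the added assertions below verify both that the metric was constructed on that device (metric._device) and that its internal state tensors (_tps, _fps, _scores) live there as well.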
@@ -623,12 +623,13 @@ def test_wrong_input():
m.update(([{"bbox": None, "scores": None, "labels": None}], [{"labels": None}]))


def test_empty_data():
def test_empty_data(available_device):
"""
Note that PyCOCO returns -1 when there's no ground truth data.
"""

metric = ObjectDetectionAvgPrecisionRecall()
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
assert metric._device == torch.device(available_device)
metric.update(
(
[{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}],
@@ -652,7 +653,8 @@ def test_empty_data():
)
assert metric.compute() == (-1, -1)

metric = ObjectDetectionAvgPrecisionRecall()
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
assert metric._device == torch.device(available_device)
metric.update(
(
[{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}],
@@ -670,7 +672,8 @@ def test_empty_data():
assert metric._y_true_count[1] == 1
assert metric.compute() == (0, 0)

metric = ObjectDetectionAvgPrecisionRecall()
metric = ObjectDetectionAvgPrecisionRecall(device=available_device)
assert metric._device == torch.device(available_device)
pred = {
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
"scores": torch.tensor([0.9]),
@@ -689,8 +692,9 @@ def test_no_torchvision():
ObjectDetectionAvgPrecisionRecall()


def test_iou(sample):
m = ObjectDetectionAvgPrecisionRecall(num_classes=91)
def test_iou(sample, available_device):
m = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=available_device)
assert m._device == torch.device(available_device)
from pycocotools.mask import iou as pycoco_iou

for pred, tgt in zip(*sample.data):
@@ -710,8 +714,9 @@ def test_iou(sample):
assert equal.all()


def test_iou_thresholding():
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75])
def test_iou_thresholding(available_device):
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75], device=available_device)
assert metric._device == torch.device(available_device)

pred = {
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
@@ -720,7 +725,7 @@ def test_iou_thresholding():
}
gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])}
metric.update(([pred], [gt]))
assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]])).all()
assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]], device=available_device)).all()

pred = {
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]),
@@ -729,10 +734,10 @@ def test_iou_thresholding():
}
gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])}
metric.update(([pred], [gt]))
assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]])).all()
assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]], device=available_device)).all()
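For reference, the two expected patterns above follow directly from the box geometry: in the first update the prediction and the ground truth overlap over a 50 x 100 region, so IoU = 5000 / (10000 + 5000 - 5000) = 0.5, which clears the 0.0, 0.3 and 0.5 thresholds but not 0.75; in the second update the boxes are disjoint, so the IoU is 0 and, assuming the threshold comparison is inclusive (>=), only the degenerate 0.0 threshold still registers a match.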


def test_matching():
def test_matching(available_device):
"""
PyCOCO matching rules:
1. The higher the confidence of a prediction, the sooner a decision is made for it.
@@ -750,7 +755,8 @@ def test_matching():
7. Non-ignored ground truths are given priority over the ignored ones when matching with a prediction
even if their IOU is lower.
"""
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2])
metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2], device=available_device)
assert metric._device == torch.device(available_device)

pred = {
"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]),
@@ -764,27 +770,27 @@ def test_matching():
}
metric.update(([pred], [gt]))
# Preds are sorted by their scores internally
assert (metric._tps[0] == torch.tensor([[True, False]])).all()
assert (metric._fps[0] == torch.tensor([[False, True]])).all()
assert (metric._scores[0] == torch.tensor([[0.9, 0.8]])).all()
assert (metric._tps[0] == torch.tensor([[True, False]], device=available_device)).all()
assert (metric._fps[0] == torch.tensor([[False, True]], device=available_device)).all()
assert (metric._scores[0] == torch.tensor([[0.9, 0.8]], device=available_device)).all()

pred["scores"] = torch.tensor([0.9, 0.9])
metric.update(([pred], [gt]))
assert (metric._tps[1] == torch.tensor([[True, False]])).all()
assert (metric._fps[1] == torch.tensor([[False, True]])).all()
assert (metric._tps[1] == torch.tensor([[True, False]], device=available_device)).all()
assert (metric._fps[1] == torch.tensor([[False, True]], device=available_device)).all()

gt["iscrowd"] = torch.tensor([1])
metric.update(([pred], [gt]))
assert (metric._tps[2] == torch.tensor([[False, False]])).all()
assert (metric._fps[2] == torch.tensor([[False, False]])).all()
assert (metric._tps[2] == torch.tensor([[False, False]], device=available_device)).all()
assert (metric._fps[2] == torch.tensor([[False, False]], device=available_device)).all()

pred["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]])
gt["bbox"] = torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]])
gt["iscrowd"] = torch.zeros((2,))
gt["labels"] = torch.tensor([1, 1])
metric.update(([pred], [gt]))
assert (metric._tps[3] == torch.tensor([[True, False]])).all()
assert (metric._fps[3] == torch.tensor([[False, True]])).all()
assert (metric._tps[3] == torch.tensor([[True, False]], device=available_device)).all()
assert (metric._fps[3] == torch.tensor([[False, True]], device=available_device)).all()

metric._area_range = "small"
pred["bbox"] = torch.tensor(
@@ -794,14 +800,14 @@ def test_matching():
pred["labels"] = torch.tensor([1, 1, 1, 1])
gt["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 5.0]])
metric.update(([pred], [gt]))
assert (metric._tps[4] == torch.tensor([[True, False, False, False]])).all()
assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all()
assert (metric._tps[4] == torch.tensor([[True, False, False, False]], device=available_device)).all()
assert (metric._fps[4] == torch.tensor([[False, False, False, True]], device=available_device)).all()

pred["scores"] = torch.tensor([0.9, 1.0, 0.9, 0.9])
metric._max_detections_per_image_per_class = 1
metric.update(([pred], [gt]))
assert (metric._tps[5] == torch.tensor([[True]])).all()
assert (metric._fps[5] == torch.tensor([[False]])).all()
assert (metric._tps[5] == torch.tensor([[True]], device=available_device)).all()
assert (metric._fps[5] == torch.tensor([[False]], device=available_device)).all()
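The behaviour asserted above can also be pictured outside the metric. The following standalone sketch (an illustration of the confidence-ordered, greedy matching described in the test_matching docstring, not code from this PR; greedy_match is a made-up name, and torchvision.ops.box_iou, which expects boxes in xyxy format, is only assumed to approximate the metric's own IoU) reproduces the first assertion block, where two identical predictions compete for a single ground truth and only the higher-scoring one wins:

import torch
from torchvision.ops import box_iou


def greedy_match(pred_boxes, pred_scores, gt_boxes, iou_threshold=0.5):
    # Visit predictions from highest to lowest score (rule 1 above) and let each
    # one claim the best still-unmatched ground truth whose IoU clears the threshold.
    order = pred_scores.argsort(descending=True)
    ious = box_iou(pred_boxes[order], gt_boxes)  # shape: (num_pred, num_gt)
    gt_taken = torch.zeros(gt_boxes.shape[0], dtype=torch.bool)
    matched = torch.zeros(pred_boxes.shape[0], dtype=torch.bool)
    for i, p in enumerate(order):
        candidate_ious = ious[i].clone()
        candidate_ious[gt_taken] = -1.0  # each ground truth may be matched at most once
        best_iou, best_gt = candidate_ious.max(dim=0)
        if best_iou >= iou_threshold:
            gt_taken[best_gt] = True
            matched[p] = True
    return matched


pred_boxes = torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]])
pred_scores = torch.tensor([0.8, 0.9])
gt_boxes = torch.tensor([[0.0, 0.0, 100.0, 100.0]])
print(greedy_match(pred_boxes, pred_scores, gt_boxes, iou_threshold=0.2))
# tensor([False,  True]) -- only the 0.9-score prediction is a true positive, which is
# what _tps[0] == [[True, False]] and _fps[0] == [[False, True]] express above (the
# test's tensors are in score-sorted order; this sketch keeps the original order).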


def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score):
@@ -825,11 +831,12 @@ def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold
return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None


def test__compute_recall_and_precision():
def test__compute_recall_and_precision(available_device):
# The case in which the detector detects all gt objects but also produces some wrong predictions.
scores = torch.rand((50,))
y_true = torch.randint(0, 2, (50,))
m = ObjectDetectionAvgPrecisionRecall()
m = ObjectDetectionAvgPrecisionRecall(device=available_device)
assert m._device == torch.device(available_device)

ignite_recall, ignite_precision = m._compute_recall_and_precision(
y_true.bool(), ~(y_true.bool()), scores, y_true.sum()