@@ -623,12 +623,13 @@ def test_wrong_input():
623623 m .update (([{"bbox" : None , "scores" : None , "labels" : None }], [{"labels" : None }]))
624624
625625
626- def test_empty_data ():
626+ def test_empty_data (available_device ):
627627 """
628628 Note that PyCOCO returns -1 when there's no ground truth data.
629629 """
630630
631- metric = ObjectDetectionAvgPrecisionRecall ()
631+ metric = ObjectDetectionAvgPrecisionRecall (device = available_device )
632+ assert metric ._device == torch .device (available_device )
632633 metric .update (
633634 (
634635 [{"bbox" : torch .zeros ((0 , 4 )), "scores" : torch .zeros ((0 ,)), "labels" : torch .zeros ((0 ), dtype = torch .long )}],
@@ -652,7 +653,8 @@ def test_empty_data():
652653 )
653654 assert metric .compute () == (- 1 , - 1 )
654655
655- metric = ObjectDetectionAvgPrecisionRecall ()
656+ metric = ObjectDetectionAvgPrecisionRecall (device = available_device )
657+ assert metric ._device == torch .device (available_device )
656658 metric .update (
657659 (
658660 [{"bbox" : torch .zeros ((0 , 4 )), "scores" : torch .zeros ((0 ,)), "labels" : torch .zeros ((0 ), dtype = torch .long )}],
@@ -670,7 +672,8 @@ def test_empty_data():
670672 assert metric ._y_true_count [1 ] == 1
671673 assert metric .compute () == (0 , 0 )
672674
673- metric = ObjectDetectionAvgPrecisionRecall ()
675+ metric = ObjectDetectionAvgPrecisionRecall (device = available_device )
676+ assert metric ._device == torch .device (available_device )
674677 pred = {
675678 "bbox" : torch .tensor ([[0.0 , 0.0 , 100.0 , 100.0 ]]),
676679 "scores" : torch .tensor ([0.9 ]),
@@ -689,8 +692,9 @@ def test_no_torchvision():
689692 ObjectDetectionAvgPrecisionRecall ()
690693
691694
692- def test_iou (sample ):
693- m = ObjectDetectionAvgPrecisionRecall (num_classes = 91 )
695+ def test_iou (sample , available_device ):
696+ m = ObjectDetectionAvgPrecisionRecall (num_classes = 91 , device = available_device )
697+ assert m ._device == torch .device (available_device )
694698 from pycocotools .mask import iou as pycoco_iou
695699
696700 for pred , tgt in zip (* sample .data ):
@@ -710,8 +714,9 @@ def test_iou(sample):
710714 assert equal .all ()
711715
712716
713- def test_iou_thresholding ():
714- metric = ObjectDetectionAvgPrecisionRecall (iou_thresholds = [0.0 , 0.3 , 0.5 , 0.75 ])
717+ def test_iou_thresholding (available_device ):
718+ metric = ObjectDetectionAvgPrecisionRecall (iou_thresholds = [0.0 , 0.3 , 0.5 , 0.75 ], device = available_device )
719+ assert metric ._device == torch .device (available_device )
715720
716721 pred = {
717722 "bbox" : torch .tensor ([[0.0 , 0.0 , 100.0 , 100.0 ]]),
@@ -720,7 +725,7 @@ def test_iou_thresholding():
720725 }
721726 gt = {"bbox" : torch .tensor ([[0.0 , 0.0 , 50.0 , 100.0 ]]), "iscrowd" : torch .zeros ((1 ,)), "labels" : torch .tensor ([1 ])}
722727 metric .update (([pred ], [gt ]))
723- assert (metric ._tps [0 ] == torch .tensor ([[True ], [True ], [True ], [False ]])).all ()
728+ assert (metric ._tps [0 ] == torch .tensor ([[True ], [True ], [True ], [False ]], device = available_device )).all ()
724729
725730 pred = {
726731 "bbox" : torch .tensor ([[0.0 , 0.0 , 100.0 , 100.0 ]]),
@@ -729,10 +734,10 @@ def test_iou_thresholding():
729734 }
730735 gt = {"bbox" : torch .tensor ([[100.0 , 0.0 , 200.0 , 100.0 ]]), "iscrowd" : torch .zeros ((1 ,)), "labels" : torch .tensor ([1 ])}
731736 metric .update (([pred ], [gt ]))
732- assert (metric ._tps [1 ] == torch .tensor ([[True ], [False ], [False ], [False ]])).all ()
737+ assert (metric ._tps [1 ] == torch .tensor ([[True ], [False ], [False ], [False ]], device = available_device )).all ()
733738
734739
735- def test_matching ():
740+ def test_matching (available_device ):
736741 """
737742 PyCOCO matching rules:
738743 1. The higher confidence in a prediction, the sooner decision is made for it.
@@ -750,7 +755,8 @@ def test_matching():
750755 7. Non-ignored ground truths are given priority over the ignored ones when matching with a prediction
751756 even if their IOU is lower.
752757 """
753- metric = ObjectDetectionAvgPrecisionRecall (iou_thresholds = [0.2 ])
758+ metric = ObjectDetectionAvgPrecisionRecall (iou_thresholds = [0.2 ], device = available_device )
759+ assert metric ._device == torch .device (available_device )
754760
755761 pred = {
756762 "bbox" : torch .tensor ([[0.0 , 0.0 , 100.0 , 100.0 ], [0.0 , 0.0 , 100.0 , 100.0 ]]),
@@ -764,27 +770,27 @@ def test_matching():
764770 }
765771 metric .update (([pred ], [gt ]))
766772 # Preds are sorted by their scores internally
767- assert (metric ._tps [0 ] == torch .tensor ([[True , False ]])).all ()
768- assert (metric ._fps [0 ] == torch .tensor ([[False , True ]])).all ()
769- assert (metric ._scores [0 ] == torch .tensor ([[0.9 , 0.8 ]])).all ()
773+ assert (metric ._tps [0 ] == torch .tensor ([[True , False ]], device = available_device )).all ()
774+ assert (metric ._fps [0 ] == torch .tensor ([[False , True ]], device = available_device )).all ()
775+ assert (metric ._scores [0 ] == torch .tensor ([[0.9 , 0.8 ]], device = available_device )).all ()
770776
771777 pred ["scores" ] = torch .tensor ([0.9 , 0.9 ])
772778 metric .update (([pred ], [gt ]))
773- assert (metric ._tps [1 ] == torch .tensor ([[True , False ]])).all ()
774- assert (metric ._fps [1 ] == torch .tensor ([[False , True ]])).all ()
779+ assert (metric ._tps [1 ] == torch .tensor ([[True , False ]], device = available_device )).all ()
780+ assert (metric ._fps [1 ] == torch .tensor ([[False , True ]], device = available_device )).all ()
775781
776782 gt ["iscrowd" ] = torch .tensor ([1 ])
777783 metric .update (([pred ], [gt ]))
778- assert (metric ._tps [2 ] == torch .tensor ([[False , False ]])).all ()
779- assert (metric ._fps [2 ] == torch .tensor ([[False , False ]])).all ()
784+ assert (metric ._tps [2 ] == torch .tensor ([[False , False ]], device = available_device )).all ()
785+ assert (metric ._fps [2 ] == torch .tensor ([[False , False ]], device = available_device )).all ()
780786
781787 pred ["bbox" ] = torch .tensor ([[0.0 , 0.0 , 100.0 , 100.0 ], [100.0 , 0.0 , 200.0 , 100.0 ]])
782788 gt ["bbox" ] = torch .tensor ([[0.0 , 0.0 , 25.0 , 50.0 ], [50.0 , 0.0 , 150.0 , 100.0 ]])
783789 gt ["iscrowd" ] = torch .zeros ((2 ,))
784790 gt ["labels" ] = torch .tensor ([1 , 1 ])
785791 metric .update (([pred ], [gt ]))
786- assert (metric ._tps [3 ] == torch .tensor ([[True , False ]])).all ()
787- assert (metric ._fps [3 ] == torch .tensor ([[False , True ]])).all ()
792+ assert (metric ._tps [3 ] == torch .tensor ([[True , False ]], device = available_device )).all ()
793+ assert (metric ._fps [3 ] == torch .tensor ([[False , True ]], device = available_device )).all ()
788794
789795 metric ._area_range = "small"
790796 pred ["bbox" ] = torch .tensor (
@@ -794,14 +800,14 @@ def test_matching():
794800 pred ["labels" ] = torch .tensor ([1 , 1 , 1 , 1 ])
795801 gt ["bbox" ] = torch .tensor ([[0.0 , 0.0 , 100.0 , 11.0 ], [0.0 , 0.0 , 100.0 , 5.0 ]])
796802 metric .update (([pred ], [gt ]))
797- assert (metric ._tps [4 ] == torch .tensor ([[True , False , False , False ]])).all ()
798- assert (metric ._fps [4 ] == torch .tensor ([[False , False , False , True ]])).all ()
803+ assert (metric ._tps [4 ] == torch .tensor ([[True , False , False , False ]], device = available_device )).all ()
804+ assert (metric ._fps [4 ] == torch .tensor ([[False , False , False , True ]], device = available_device )).all ()
799805
800806 pred ["scores" ] = torch .tensor ([0.9 , 1.0 , 0.9 , 0.9 ])
801807 metric ._max_detections_per_image_per_class = 1
802808 metric .update (([pred ], [gt ]))
803- assert (metric ._tps [5 ] == torch .tensor ([[True ]])).all ()
804- assert (metric ._fps [5 ] == torch .tensor ([[False ]])).all ()
809+ assert (metric ._tps [5 ] == torch .tensor ([[True ]], device = available_device )).all ()
810+ assert (metric ._fps [5 ] == torch .tensor ([[False ]], device = available_device )).all ()
805811
806812
807813def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold (y_true , y_score ):
@@ -825,11 +831,12 @@ def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold
825831 return np .hstack ((precision [sl ], 1 )), np .hstack ((recall [sl ], 0 )), None
826832
827833
828- def test__compute_recall_and_precision ():
834+ def test__compute_recall_and_precision (available_device ):
829835 # The case in which detector detects all gt objects but also produces some wrong predictions.
830836 scores = torch .rand ((50 ,))
831837 y_true = torch .randint (0 , 2 , (50 ,))
832- m = ObjectDetectionAvgPrecisionRecall ()
838+ m = ObjectDetectionAvgPrecisionRecall (device = available_device )
839+ assert m ._device == torch .device (available_device )
833840
834841 ignite_recall , ignite_precision = m ._compute_recall_and_precision (
835842 y_true .bool (), ~ (y_true .bool ()), scores , y_true .sum ()