@@ -711,21 +711,20 @@ def _parse_padding(padding):
711711@pytest .mark .parametrize ("device" , cpu_and_cuda ())
712712@pytest .mark .parametrize ("padding" , [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]])
713713def test_correctness_pad_bounding_boxes (device , padding ):
714- def _compute_expected_bbox (bbox , padding_ ):
714+ def _compute_expected_bbox (bbox , format , padding_ ):
715715 pad_left , pad_up , _ , _ = _parse_padding (padding_ )
716716
717717 dtype = bbox .dtype
718- format = bbox .format
719718 bbox = (
720719 bbox .clone ()
721720 if format == datapoints .BoundingBoxFormat .XYXY
722- else convert_format_bounding_boxes (bbox , new_format = datapoints .BoundingBoxFormat .XYXY )
721+ else convert_format_bounding_boxes (bbox , old_format = format , new_format = datapoints .BoundingBoxFormat .XYXY )
723722 )
724723
725724 bbox [0 ::2 ] += pad_left
726725 bbox [1 ::2 ] += pad_up
727726
728- bbox = convert_format_bounding_boxes (bbox , new_format = format )
727+ bbox = convert_format_bounding_boxes (bbox , old_format = datapoints . BoundingBoxFormat . XYXY , new_format = format )
729728 if bbox .dtype != dtype :
730729 # Temporary cast to original dtype
731730 # e.g. float32 -> int
@@ -737,7 +736,7 @@ def _compute_expected_canvas_size(bbox, padding_):
737736 height , width = bbox .canvas_size
738737 return height + pad_up + pad_down , width + pad_left + pad_right
739738
740- for bboxes in make_bounding_boxes ():
739+ for bboxes in make_bounding_boxes (extra_dims = (( 4 ,),) ):
741740 bboxes = bboxes .to (device )
742741 bboxes_format = bboxes .format
743742 bboxes_canvas_size = bboxes .canvas_size
@@ -748,18 +747,10 @@ def _compute_expected_canvas_size(bbox, padding_):
748747
749748 torch .testing .assert_close (output_canvas_size , _compute_expected_canvas_size (bboxes , padding ))
750749
751- if bboxes .ndim < 2 or bboxes .shape [0 ] == 0 :
752- bboxes = [bboxes ]
753-
754- expected_bboxes = []
755- for bbox in bboxes :
756- bbox = datapoints .BoundingBoxes (bbox , format = bboxes_format , canvas_size = bboxes_canvas_size )
757- expected_bboxes .append (_compute_expected_bbox (bbox , padding ))
750+ expected_bboxes = torch .stack (
751+ [_compute_expected_bbox (b , bboxes_format , padding ) for b in bboxes .reshape (- 1 , 4 ).unbind ()]
752+ ).reshape (bboxes .shape )
758753
759- if len (expected_bboxes ) > 1 :
760- expected_bboxes = torch .stack (expected_bboxes )
761- else :
762- expected_bboxes = expected_bboxes [0 ]
763754 torch .testing .assert_close (output_boxes , expected_bboxes , atol = 1 , rtol = 0 )
764755
765756
@@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
784775 ],
785776)
786777def test_correctness_perspective_bounding_boxes (device , startpoints , endpoints ):
787- def _compute_expected_bbox (bbox , pcoeffs_ ):
778+ def _compute_expected_bbox (bbox , format_ , canvas_size_ , pcoeffs_ ):
788779 m1 = np .array (
789780 [
790781 [pcoeffs_ [0 ], pcoeffs_ [1 ], pcoeffs_ [2 ]],
@@ -798,7 +789,9 @@ def _compute_expected_bbox(bbox, pcoeffs_):
798789 ]
799790 )
800791
801- bbox_xyxy = convert_format_bounding_boxes (bbox , new_format = datapoints .BoundingBoxFormat .XYXY )
792+ bbox_xyxy = convert_format_bounding_boxes (
793+ bbox , old_format = format_ , new_format = datapoints .BoundingBoxFormat .XYXY
794+ )
802795 points = np .array (
803796 [
804797 [bbox_xyxy [0 ].item (), bbox_xyxy [1 ].item (), 1.0 ],
@@ -818,14 +811,11 @@ def _compute_expected_bbox(bbox, pcoeffs_):
818811 np .max (transformed_points [:, 1 ]),
819812 ]
820813 )
821- out_bbox = datapoints .BoundingBoxes (
822- out_bbox ,
823- format = datapoints .BoundingBoxFormat .XYXY ,
824- canvas_size = bbox .canvas_size ,
825- dtype = bbox .dtype ,
826- device = bbox .device ,
814+ out_bbox = torch .from_numpy (out_bbox )
815+ out_bbox = convert_format_bounding_boxes (
816+ out_bbox , old_format = datapoints .BoundingBoxFormat .XYXY , new_format = format_
827817 )
828- return clamp_bounding_boxes (convert_format_bounding_boxes ( out_bbox , new_format = bbox . format ) )
818+ return clamp_bounding_boxes (out_bbox , format = format_ , canvas_size = canvas_size_ ). to ( bbox )
829819
830820 canvas_size = (32 , 38 )
831821
@@ -844,17 +834,13 @@ def _compute_expected_bbox(bbox, pcoeffs_):
844834 coefficients = pcoeffs ,
845835 )
846836
847- if bboxes .ndim < 2 :
848- bboxes = [bboxes ]
837+ expected_bboxes = torch .stack (
838+ [
839+ _compute_expected_bbox (b , bboxes .format , bboxes .canvas_size , inv_pcoeffs )
840+ for b in bboxes .reshape (- 1 , 4 ).unbind ()
841+ ]
842+ ).reshape (bboxes .shape )
849843
850- expected_bboxes = []
851- for bbox in bboxes :
852- bbox = datapoints .BoundingBoxes (bbox , format = bboxes .format , canvas_size = bboxes .canvas_size )
853- expected_bboxes .append (_compute_expected_bbox (bbox , inv_pcoeffs ))
854- if len (expected_bboxes ) > 1 :
855- expected_bboxes = torch .stack (expected_bboxes )
856- else :
857- expected_bboxes = expected_bboxes [0 ]
858844 torch .testing .assert_close (output_bboxes , expected_bboxes , rtol = 0 , atol = 1 )
859845
860846
@@ -864,9 +850,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
864850 [(18 , 18 ), [18 , 15 ], (16 , 19 ), [12 ], [46 , 48 ]],
865851)
866852def test_correctness_center_crop_bounding_boxes (device , output_size ):
867- def _compute_expected_bbox (bbox , output_size_ ):
868- format_ = bbox .format
869- canvas_size_ = bbox .canvas_size
853+ def _compute_expected_bbox (bbox , format_ , canvas_size_ , output_size_ ):
870854 dtype = bbox .dtype
871855 bbox = convert_format_bounding_boxes (bbox .float (), format_ , datapoints .BoundingBoxFormat .XYWH )
872856
@@ -895,18 +879,12 @@ def _compute_expected_bbox(bbox, output_size_):
895879 bboxes , bboxes_format , bboxes_canvas_size , output_size
896880 )
897881
898- if bboxes .ndim < 2 :
899- bboxes = [bboxes ]
900-
901- expected_bboxes = []
902- for bbox in bboxes :
903- bbox = datapoints .BoundingBoxes (bbox , format = bboxes_format , canvas_size = bboxes_canvas_size )
904- expected_bboxes .append (_compute_expected_bbox (bbox , output_size ))
905-
906- if len (expected_bboxes ) > 1 :
907- expected_bboxes = torch .stack (expected_bboxes )
908- else :
909- expected_bboxes = expected_bboxes [0 ]
882+ expected_bboxes = torch .stack (
883+ [
884+ _compute_expected_bbox (b , bboxes_format , bboxes_canvas_size , output_size )
885+ for b in bboxes .reshape (- 1 , 4 ).unbind ()
886+ ]
887+ ).reshape (bboxes .shape )
910888
911889 torch .testing .assert_close (output_boxes , expected_bboxes , atol = 1 , rtol = 0 )
912890 torch .testing .assert_close (output_canvas_size , output_size )
0 commit comments