Skip to content

Commit 9d11700

Browse files
committed
fix_bbox_sanitize_tensor
1 parent 054432d commit 9d11700

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

gallery/v2_transforms/plot_transforms_v2.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,9 @@
9999
format="XYXY", canvas_size=img.shape[-2:])
100100

101101
transforms = v2.Compose([
102-
v2.RandomPhotometricDistort(),
103-
v2.RandomIoUCrop(),
104-
v2.RandomHorizontalFlip(p=0.5),
105-
v2.SanitizeBoundingBoxes(),
102+
v2.RandomResizedCrop(size=(224, 224), antialias=True),
103+
v2.RandomPhotometricDistort(p=1),
104+
v2.RandomHorizontalFlip(p=1),
106105
])
107106
out_img, out_bboxes = transforms(img, bboxes)
108107

test/test_transforms_v2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
12561256
assert out_labels.tolist() == valid_indices
12571257

12581258

1259+
def test_sanitize_bounding_boxes_no_label():
1260+
# Non-regression test for https://github.com/pytorch/vision/issues/7878
1261+
1262+
img = make_image()
1263+
boxes = make_bounding_boxes()
1264+
1265+
with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
1266+
transforms.SanitizeBoundingBoxes()(img, boxes)
1267+
1268+
out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
1269+
assert isinstance(out_img, datapoints.Image)
1270+
assert isinstance(out_boxes, datapoints.BoundingBoxes)
1271+
1272+
12591273
def test_sanitize_bounding_boxes_errors():
12601274

12611275
good_bbox = datapoints.BoundingBoxes(

torchvision/transforms/v2/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
112112
inputs = inputs[1]
113113

114114
# MixUp, CutMix
115-
if isinstance(inputs, torch.Tensor):
115+
if is_pure_tensor(inputs):
116116
return inputs
117117

118118
if not isinstance(inputs, collections.abc.Mapping):

0 commit comments

Comments
 (0)