diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index dc4d1f4723c..e358c83d9d1 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -230,10 +230,13 @@ def wrapper(idx, sample):
         batched_target["image_id"] = image_id
 
         spatial_size = tuple(F.get_spatial_size(image))
-        batched_target["boxes"] = datapoints.BoundingBox(
-            batched_target["bbox"],
-            format=datapoints.BoundingBoxFormat.XYWH,
-            spatial_size=spatial_size,
+        batched_target["boxes"] = F.convert_format_bounding_box(
+            datapoints.BoundingBox(
+                batched_target["bbox"],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            new_format=datapoints.BoundingBoxFormat.XYXY,
         )
         batched_target["masks"] = datapoints.Mask(
             torch.stack(
@@ -323,8 +326,13 @@ def wrapper(idx, sample):
             target,
             target_types=dataset.target_type,
             type_wrappers={
-                "bbox": lambda item: datapoints.BoundingBox(
-                    item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
+                "bbox": lambda item: F.convert_format_bounding_box(
+                    datapoints.BoundingBox(
+                        item,
+                        format=datapoints.BoundingBoxFormat.XYWH,
+                        spatial_size=(image.height, image.width),
+                    ),
+                    new_format=datapoints.BoundingBoxFormat.XYXY,
                 ),
             },
         )
@@ -416,8 +424,11 @@ def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
-            target["bbox"] = datapoints.BoundingBox(
-                target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
+            target["bbox"] = F.convert_format_bounding_box(
+                datapoints.BoundingBox(
+                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
+                ),
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            )
 
         return image, target
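
Below is a minimal sketch of the conversion these wrappers now perform, using only names that appear in the diff above (datapoints.BoundingBox, datapoints.BoundingBoxFormat, F.convert_format_bounding_box, with F assumed to be torchvision.transforms.v2.functional as imported in this file at the time); the tensor values are illustrative only:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    # A single COCO-style box: (x, y, width, height).
    xywh = datapoints.BoundingBox(
        torch.tensor([[10.0, 20.0, 30.0, 40.0]]),
        format=datapoints.BoundingBoxFormat.XYWH,
        spatial_size=(100, 100),
    )

    # The same call pattern the wrappers use: XYWH -> XYXY,
    # i.e. (x, y, w, h) becomes (x1, y1, x2, y2) = (x, y, x + w, y + h).
    xyxy = F.convert_format_bounding_box(xywh, new_format=datapoints.BoundingBoxFormat.XYXY)
    print(xyxy)  # tensor([[10., 20., 40., 60.]]), now in XYXY format

The old format is inferred from the BoundingBox datapoint itself, which is why the wrappers only pass new_format; returning XYXY means downstream v2 transforms and detection models receive boxes in the format they expect, instead of COCO/CelebA/WIDERFace's native XYWH.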