
Commit 31df2ab

Merge branch 'main' into make-image-channels-last
2 parents 6ef8fce + 9f0afd5 commit 31df2ab

14 files changed: +140 −35 lines changed


gallery/v2_transforms/plot_transforms_v2.py

Lines changed: 3 additions & 4 deletions
@@ -99,10 +99,9 @@
     format="XYXY", canvas_size=img.shape[-2:])
 
 transforms = v2.Compose([
-    v2.RandomPhotometricDistort(),
-    v2.RandomIoUCrop(),
-    v2.RandomHorizontalFlip(p=0.5),
-    v2.SanitizeBoundingBoxes(),
+    v2.RandomResizedCrop(size=(224, 224), antialias=True),
+    v2.RandomPhotometricDistort(p=1),
+    v2.RandomHorizontalFlip(p=1),
 ])
 out_img, out_bboxes = transforms(img, bboxes)
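
As a sanity check of the updated example, here is a minimal, self-contained sketch of the new pipeline; the random image and the single box are fabricated stand-ins for the real inputs the gallery script builds:

import torch
from torchvision import datapoints
from torchvision.transforms import v2

# Fake 3-channel uint8 image and one box, just to exercise the pipeline.
img = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
bboxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100]], format="XYXY", canvas_size=img.shape[-2:]
)

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)
print(out_img.shape, out_bboxes)  # the boxes are transformed alongside the image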

references/classification/presets.py

Lines changed: 2 additions & 2 deletions
@@ -61,7 +61,7 @@ def __init__(
 
         transforms.extend(
             [
-                T.ConvertImageDtype(torch.float),
+                T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
                 T.Normalize(mean=mean, std=std),
             ]
         )
@@ -106,7 +106,7 @@ def __init__(
             transforms.append(T.PILToTensor())
 
         transforms += [
-            T.ConvertImageDtype(torch.float),
+            T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
             T.Normalize(mean=mean, std=std),
         ]
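
The swap keeps v1 behavior intact: with scale=True, v2's ToDtype performs the same value rescaling that ConvertImageDtype always did. A rough sketch of the equivalence on a uint8 image:

import torch
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))

# scale=True rescales uint8 [0, 255] to float [0, 1], matching ConvertImageDtype.
scaled = v2.ToDtype(torch.float, scale=True)(img)
assert float(scaled.max()) <= 1.0

# scale=False is a bare cast that keeps the 0..255 magnitudes.
cast = v2.ToDtype(torch.float, scale=False)(img)
assert float(cast.max()) == float(img.max())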

references/detection/presets.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ def __init__(
         # Note: we could just convert to pure tensors even in v2.
         transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 
-        transforms += [T.ConvertImageDtype(torch.float)]
+        transforms += [T.ToDtype(torch.float, scale=True)]
 
         if use_v2:
             transforms += [
@@ -103,7 +103,7 @@ def __init__(self, backend="pil", use_v2=False):
         else:
             raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
 
-        transforms += [T.ConvertImageDtype(torch.float)]
+        transforms += [T.ToDtype(torch.float, scale=True)]
 
         if use_v2:
             transforms += [T.ToPureTensor()]

references/detection/transforms.py

Lines changed: 5 additions & 2 deletions
@@ -53,14 +53,17 @@ def forward(
         return image, target
 
 
-class ConvertImageDtype(nn.Module):
-    def __init__(self, dtype: torch.dtype) -> None:
+class ToDtype(nn.Module):
+    def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
         super().__init__()
         self.dtype = dtype
+        self.scale = scale
 
     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if not self.scale:
+            return image.to(dtype=self.dtype), target
         image = F.convert_image_dtype(image, self.dtype)
         return image, target
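
In isolation, the renamed class now distinguishes a bare cast from the rescaling cast; a sketch of both paths, assuming ToDtype is imported from the reference transforms module shown above (the import line is hypothetical):

import torch
# hypothetical import; the class lives in references/detection/transforms.py
from transforms import ToDtype

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Default scale=False: plain dtype cast, pixel values stay in 0..255.
img_cast, _ = ToDtype(torch.float)(img)
assert float(img_cast.max()) == float(img.max())

# scale=True: the same rescaling to 0..1 that ConvertImageDtype performed.
img_scaled, _ = ToDtype(torch.float, scale=True)(img)
assert float(img_scaled.max()) <= 1.0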

references/segmentation/presets.py

Lines changed: 2 additions & 2 deletions
@@ -60,7 +60,7 @@ def __init__(
             ]
         else:
             # No need to explicitly convert masks as they're magically int64 already
-            transforms += [T.ConvertImageDtype(torch.float)]
+            transforms += [T.ToDtype(torch.float, scale=True)]
 
         transforms += [T.Normalize(mean=mean, std=std)]
         if use_v2:
@@ -97,7 +97,7 @@ def __init__(
         transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 
         transforms += [
-            T.ConvertImageDtype(torch.float),
+            T.ToDtype(torch.float, scale=True),
             T.Normalize(mean=mean, std=std),
         ]
         if use_v2:

references/segmentation/transforms.py

Lines changed: 5 additions & 2 deletions
@@ -81,11 +81,14 @@ def __call__(self, image, target):
         return image, target
 
 
-class ConvertImageDtype:
-    def __init__(self, dtype):
+class ToDtype:
+    def __init__(self, dtype, scale=False):
         self.dtype = dtype
+        self.scale = scale
 
     def __call__(self, image, target):
+        if not self.scale:
+            return image.to(dtype=self.dtype), target
         image = F.convert_image_dtype(image, self.dtype)
         return image, target

references/segmentation/v2_extras.py

Lines changed: 1 addition & 1 deletion
@@ -78,6 +78,6 @@ def _coco_detection_masks_to_voc_segmentation_mask(self, target):
     def forward(self, image, target):
         segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
         if segmentation_mask is None:
-            segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
+            segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
 
         return image, datapoints.Mask(segmentation_mask)
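
The rename tracks the v2 API: get_spatial_size became get_size, and both return the [height, width] of the input, so torch.zeros(...) above still builds an all-background H×W mask. A small sketch (the image is fabricated):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

image = datapoints.Image(torch.zeros(3, 240, 320, dtype=torch.uint8))
print(v2.functional.get_size(image))  # [240, 320], i.e. [H, W]

mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
print(mask.shape)  # torch.Size([240, 320])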

test/datasets_utils.py

Lines changed: 17 additions & 5 deletions
@@ -662,27 +662,39 @@ class VideoDatasetTestCase(DatasetTestCase):
     FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
     REQUIRED_PACKAGES = ("av",)
 
-    DEFAULT_FRAMES_PER_CLIP = 1
+    FRAMES_PER_CLIP = 1
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)
 
-    def _set_default_frames_per_clip(self, inject_fake_data):
+    def _set_default_frames_per_clip(self, dataset_args):
         argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
         args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
         frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
 
-        @functools.wraps(inject_fake_data)
+        @functools.wraps(dataset_args)
         def wrapper(tmpdir, config):
-            args = inject_fake_data(tmpdir, config)
+            args = dataset_args(tmpdir, config)
             if frames_per_clip_last and len(args) == len(args_without_default) - 1:
-                args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
+                args = (*args, self.FRAMES_PER_CLIP)
 
             return args
 
         return wrapper
 
+    def test_output_format(self):
+        for output_format in ["TCHW", "THWC"]:
+            with self.create_dataset(output_format=output_format) as (dataset, _):
+                for video, *_ in dataset:
+                    if output_format == "TCHW":
+                        num_frames, num_channels, *_ = video.shape
+                    else:  # output_format == "THWC"
+                        num_frames, *_, num_channels = video.shape
+
+                    assert num_frames == self.FRAMES_PER_CLIP
+                    assert num_channels == 3
+
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
         # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
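
The new test pins down the two video layouts: with output_format="TCHW" the channel axis follows the frame axis, while "THWC" puts channels last. The unpacking logic in isolation, on a fabricated tensor:

import torch

video_tchw = torch.zeros(1, 3, 240, 320)      # (T, C, H, W)
video_thwc = video_tchw.permute(0, 2, 3, 1)   # (T, H, W, C)

num_frames, num_channels, *_ = video_tchw.shape
assert (num_frames, num_channels) == (1, 3)

num_frames, *_, num_channels = video_thwc.shape
assert (num_frames, num_channels) == (1, 3)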

test/test_image.py

Lines changed: 1 addition & 4 deletions
@@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
     with Image.open(img_path) as img:
         is_cmyk = img.mode == "CMYK"
         if pil_mode is not None:
-            if is_cmyk:
-                # libjpeg does not support the conversion
-                pytest.xfail("Decoding a CMYK jpeg isn't supported")
             img = img.convert(pil_mode)
         img_pil = torch.from_numpy(np.array(img))
-        if is_cmyk:
+        if is_cmyk and mode == ImageReadMode.UNCHANGED:
             # flip the colors to match libjpeg
             img_pil = 255 - img_pil
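
After this change, CMYK JPEGs are expected to decode rather than xfail; only ImageReadMode.UNCHANGED keeps libjpeg's raw CMYK values, which is why the color flip is now gated on that mode. A hedged sketch (the file path is a placeholder):

from torchvision.io import ImageReadMode, decode_jpeg, read_file

data = read_file("cmyk_image.jpg")  # placeholder path to a CMYK JPEG

rgb = decode_jpeg(data, mode=ImageReadMode.RGB)        # converted to 3-channel RGB
raw = decode_jpeg(data, mode=ImageReadMode.UNCHANGED)  # raw values as libjpeg stores
                                                       # them, hence the 255 - x flip
                                                       # in the test above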

test/test_transforms_v2.py

Lines changed: 14 additions & 0 deletions
@@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     assert out_labels.tolist() == valid_indices
 
 
+def test_sanitize_bounding_boxes_no_label():
+    # Non-regression test for https://github.com/pytorch/vision/issues/7878
+
+    img = make_image()
+    boxes = make_bounding_boxes()
+
+    with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
+        transforms.SanitizeBoundingBoxes()(img, boxes)
+
+    out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
+    assert isinstance(out_img, datapoints.Image)
+    assert isinstance(out_boxes, datapoints.BoundingBoxes)
+
+
 def test_sanitize_bounding_boxes_errors():
 
     good_bbox = datapoints.BoundingBoxes(
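
For reference, labels_getter=None tells SanitizeBoundingBoxes to skip label lookup entirely, which is what makes the bare (image, boxes) call in the new test valid. A usage sketch with the era's datapoints API (the box values are made up):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.zeros(3, 100, 100, dtype=torch.uint8))
boxes = datapoints.BoundingBoxes(
    [[0, 0, 10, 10], [5, 5, 5, 5]],  # the second, zero-area box gets removed
    format="XYXY",
    canvas_size=(100, 100),
)

out_img, out_boxes = v2.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
print(len(out_boxes))  # 1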
