Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import os
import pathlib
import pickle
import random
import shutil
import string
Expand Down Expand Up @@ -572,35 +573,42 @@ def test_transforms_v2_wrapper(self, config):

try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
torchvision.datasets.VOCDetection,
torchvision.datasets.Kitti,
torchvision.datasets.WIDERFace,
}:
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
continue

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]

wrapped_sample = wrapped_dataset[0]
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
)
wrap_dataset_for_transforms_v2(dataset)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
pytest.skip(msg)
return
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
pytest.skip("Config is currently not supported by this wrapper")
return
raise error

for target_keys, de_serialize in itertools.product(
[None, "all"], [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))]
):

with self.create_dataset(config) as (dataset, info):
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
torchvision.datasets.VOCDetection,
torchvision.datasets.Kitti,
torchvision.datasets.WIDERFace,
}:
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
continue

wrapped_dataset = de_serialize(wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys))

assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]

wrapped_sample = wrapped_dataset[0]
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
)


class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
Expand Down
7 changes: 6 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import pathlib
import pickle
import random
import warnings

Expand Down Expand Up @@ -169,8 +170,11 @@ class TestSmoke:
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_common(self, transform, adapter, container_type, image_or_video, device):
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
transform = de_serialize(transform)

canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
Expand Down Expand Up @@ -234,6 +238,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
tensor=torch.empty(5),
array=np.empty(5),
)

if adapter is not None:
input = adapter(transform, input, device)

Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import decimal
import inspect
import math
import pickle
import re
from pathlib import Path
from unittest import mock
Expand Down Expand Up @@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
def check_transform(transform_cls, input, *args, **kwargs):
transform = transform_cls(*args, **kwargs)

pickle.loads(pickle.dumps(transform))

output = transform(input)
assert isinstance(output, type(input))

Expand Down
9 changes: 8 additions & 1 deletion torchvision/datapoints/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import collections.abc

import contextlib
from collections import defaultdict

Expand Down Expand Up @@ -97,6 +96,10 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
f"but got {target_keys}"
)

return _make_wrapped_dataset(dataset, target_keys)


def _make_wrapped_dataset(dataset, target_keys):
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
Expand Down Expand Up @@ -162,6 +165,7 @@ def __init__(self, dataset, target_keys):
raise TypeError(msg)

self._dataset = dataset
self._target_keys = target_keys
self._wrapper = wrapper_factory(dataset, target_keys)

# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
Expand Down Expand Up @@ -197,6 +201,9 @@ def __getitem__(self, idx):
def __len__(self):
    """Return the number of samples, delegating to the wrapped dataset."""
    wrapped = self._dataset
    return len(wrapped)

def __reduce__(self):
    """Make the wrapper picklable.

    The wrapping class is generated dynamically at runtime, so instances
    cannot be pickled by class reference. Instead, pickle the underlying
    dataset together with the stored ``target_keys`` and rebuild the
    wrapper on unpickling via ``_make_wrapped_dataset``.
    """
    state = (self._dataset, self._target_keys)
    return (_make_wrapped_dataset, state)


def raise_not_supported(description):
raise RuntimeError(
Expand Down