5 changes: 5 additions & 0 deletions docs/source/datasets.rst
@@ -169,3 +169,8 @@ Base classes for custom datasets
DatasetFolder
ImageFolder
VisionDataset

Transforms v2
-------------

.. autofunction:: wrap_dataset_for_transforms_v2
65 changes: 51 additions & 14 deletions torchvision/datapoints/_dataset_wrapper.py
@@ -16,6 +16,51 @@

# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset):
"""Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

.. note::

So far we only provide wrappers for the most popular datasets. Furthermore, the wrappers only support dataset
configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

The dataset samples are wrapped according to the description below.

Special

* :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the wrapper
returns it as a dict of lists. In addition, the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"`` are
added, wrapping the data in the corresponding ``torchvision.datapoints`` (see the usage sketch after this
function).
* :class:`~torchvision.datasets.VOCDetection`
* :class:`~torchvision.datasets.SBDataset`
* :class:`~torchvision.datasets.CelebA`
* :class:`~torchvision.datasets.Kitti`
* :class:`~torchvision.datasets.OxfordIIITPet`
* :class:`~torchvision.datasets.Cityscapes`
* :class:`~torchvision.datasets.WIDERFace`

Image classification datasets

This wrapper is a no-op for image classification datasets, since they were already fully supported by
:mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

Segmentation datasets

Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
:class:`PIL.Image.Image` objects. This wrapper leaves the image, i.e. the first item, as is, while wrapping the
segmentation mask, i.e. the second item, into a :class:`~torchvision.datapoints.Mask`.

Video classification datasets

Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
:class:`torch.Tensor` for the video and audio, and an :class:`int` as the label. This wrapper wraps the video into a
:class:`~torchvision.datapoints.Video` while leaving the other items as is.

.. note::

Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
"""
return VisionDatasetDatapointWrapper(dataset)
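
# --- Editorial usage sketch, not part of this diff. The import path is
# assumed from this module's location, and the dataset paths are
# placeholders. ---
import torchvision.transforms.v2 as transforms_v2
from torchvision import datasets
from torchvision.datapoints import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset)

image, target = dataset[0]
# Per the docstring above, the Coco target is now a dict of lists with the
# extra keys "boxes", "masks", and "labels" wrapping the data in datapoints.
boxes, masks, labels = target["boxes"], target["masks"], target["labels"]

# The wrapped sample can be passed through the joint transforms of
# torchvision.transforms.v2, which transform image and target together.
transform = transforms_v2.Compose([transforms_v2.RandomHorizontalFlip(p=1.0)])
image, target = transform(image, target)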


@@ -103,21 +148,13 @@ def raise_not_supported(description):
    )


-def identity(item):
-    return item
-
-
def identity_wrapper_factory(dataset):
    def wrapper(idx, sample):
        return sample

    return wrapper


-def pil_image_to_mask(pil_image):
-    return datapoints.Mask(pil_image)
-
-
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    dict_of_lists = defaultdict(list)
    for dct in list_of_dicts:
@@ -131,7 +168,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
        target = [target]

    wrapped_target = tuple(
-        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
+        type_wrappers.get(target_type, lambda x: x)(item) for target_type, item in zip(target_types, target)
    )

    if len(wrapped_target) == 1:
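
# Editorial sketch of the dispatch above (inputs assumed): each target item
# is wrapped by the function registered for its target type, while types
# without an entry in type_wrappers fall back to the lambda x: x identity
# default.
#
#   wrap_target_by_type(
#       (pil_mask, 37),
#       target_types=["segmentation", "category"],
#       type_wrappers={"segmentation": datapoints.Mask},
#   )
#   # -> (datapoints.Mask(...), 37)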
@@ -161,7 +198,7 @@ def classification_wrapper_factory(dataset):
def segmentation_wrapper_factory(dataset):
    def wrapper(idx, sample):
        image, mask = sample
-        return image, pil_image_to_mask(mask)
+        return image, datapoints.Mask(mask)

    return wrapper
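
# Editorial sketch of the effect (VOCSegmentation path is a placeholder):
#
#   dataset = wrap_dataset_for_transforms_v2(datasets.VOCSegmentation("path/to/voc"))
#   image, mask = dataset[0]
#   # image is left as a PIL.Image.Image, mask is now a datapoints.Mask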

@@ -307,7 +344,7 @@ def wrapper(idx, sample):


@WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbd_wrapper(dataset):
+def sbdataset_wrapper(dataset):
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

@@ -374,7 +411,7 @@ def wrapper(idx, sample):
                target,
                target_types=dataset._target_types,
                type_wrappers={
-                    "segmentation": pil_image_to_mask,
+                    "segmentation": datapoints.Mask,
                },
            )
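
# Editorial sketch, assuming this hunk belongs to the OxfordIIITPet wrapper
# (suggested by the private dataset._target_types attribute above): a sample
# requested with target_types=["category", "segmentation"] would come back
# as (image, (category_int, datapoints.Mask(...))), with only the
# segmentation target wrapped.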

@@ -390,7 +427,7 @@ def cityscapes_wrapper_factory(dataset):

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
-        data = pil_image_to_mask(mask)
+        data = datapoints.Mask(mask)
        masks = []
        labels = []
        for id in data.unique():
@@ -409,7 +446,7 @@ def wrapper(idx, sample):
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
-                "semantic": pil_image_to_mask,
+                "semantic": datapoints.Mask,
            },
        )
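
# Editorial note on the id scheme handled by instance_segmentation_wrapper
# above: per the linked cityscapesScripts file, pixels that belong to an
# instance are encoded as class_id * 1000 + instance_index (e.g. 26001 for
# instance 1 of class 26), while pixels of classes without instances carry
# the bare class_id. A hedged sketch of the per-id step inside the loop:
#
#   masks.append(data == id)
#   labels.append(id // 1000 if id >= 1000 else id)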
