From 1df3b8a88d6d64f9a89997127a98242d936b5ce4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 28 Jul 2023 09:16:30 +0100 Subject: [PATCH 1/6] upper case for names --- docs/source/transforms.rst | 4 ++-- gallery/plot_cutmix_mixup.py | 2 +- torchvision/transforms/v2/_augment.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b29c22ee1f6..c4e4736b8db 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -261,10 +261,10 @@ The new transform can be used standalone or mixed-and-matched with existing tran AugMix v2.AugMix -Cutmix - Mixup +CutMix - MixUp -------------- -Cutmix and Mixup are special transforms that +CutMix and MixUp are special transforms that are meant to be used on batches rather than on individual images, because they are combining pairs of images together. These can be used after the dataloader, or part of a collation function. See diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index 19838fe907d..01d9db606a6 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -1,7 +1,7 @@ """ =========================== -How to use Cutmix and Mixup +How to use CutMix and MixUp =========================== TODO diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index fb4e23f5fe2..0841eebd983 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -204,7 +204,7 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: class Mixup(_BaseMixupCutmix): - """[BETA] Apply Mixup to the provided batch of images and labels. + """[BETA] Apply MixUp to the provided batch of images and labels. .. v2betastatus:: Mixup transform @@ -246,7 +246,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class Cutmix(_BaseMixupCutmix): - """[BETA] Apply Cutmix to the provided batch of images and labels. + """[BETA] Apply CutMix to the provided batch of images and labels. .. v2betastatus:: Cutmix transform From 8b8b7529aaf5f2e143d3fc65e06cc6dfcfd12d90 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 28 Jul 2023 10:18:20 +0100 Subject: [PATCH 2/6] Add CutMix and Mixup gallery example --- docs/Makefile | 2 +- docs/source/transforms.rst | 4 +- gallery/plot_cutmix_mixup.py | 141 +++++++++++++++++++++++++- test/test_transforms_v2_refactored.py | 2 +- torchvision/transforms/v2/_augment.py | 20 +++- 5 files changed, 161 insertions(+), 8 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index f462ff22303..f5987993175 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),) endif # You can set these variables from the command line. -SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) +SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS) SPHINXBUILD = sphinx-build SPHINXPROJ = torchvision SOURCEDIR = source diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c4e4736b8db..21b5986476e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -266,8 +266,8 @@ CutMix - MixUp CutMix and MixUp are special transforms that are meant to be used on batches rather than on individual images, because they -are combining pairs of images together. These can be used after the dataloader, -or part of a collation function. See +are combining pairs of images together. These can be used after the dataloader +(once the samples are batched), or part of a collation function. See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. .. autosummary:: diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index 01d9db606a6..1f7f91282aa 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -4,5 +4,144 @@ How to use CutMix and MixUp =========================== -TODO +:class:`~torchvision.transforms.v2.Cutmix` and +:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies +that can improve classification accuracy. + +These transforms are slightly different from the rest of the Torchvision +transforms, because they expect +**batches** of samples as input, not individual images. In this example we'll +explain how to use them: after the ``DataLoader``, or as part of a collation +function. """ + +# %% +import torch +import torchvision +from torchvision.datasets import FakeData + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision.transforms import v2 + + +NUM_CLASSES = 100 + +# %% +# Pre-processing pipeline +# ----------------------- +# +# We'll use a simple but typical image classification pipeline: + +preproc = v2.Compose([ + v2.PILToTensor(), + v2.RandomResizedCrop(size=(224, 224), antialias=True), + v2.RandomHorizontalFlip(p=0.5), + v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet +]) + +dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc) + +img, label = dataset[0] +print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }") + +# %% +# +# One important thing to note is that neither CutMix nor MixUp are part of this +# pre-processing pipeline. We'll add them a bit later once we define the +# DataLoader. Just as a refresher, this is what the DataLoader and training loop +# would look like if we weren't using CutMix or MixUp: + +from torch.utils.data import DataLoader + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + print(labels.dtype) + # + break +# %% + +# %% +# Where to use MixUp and CutMix +# ----------------------------- +# +# After the DataLoader +# ^^^^^^^^^^^^^^^^^^^^ +# +# Now let's add CutMix and MixUp. The simplest way to do this right after the +# DataLoader: the Dataloader has already batched the images and labels for us, +# and this is exactly what these transforms expect as input: + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +cutmix = v2.Cutmix(num_classes=NUM_CLASSES) +mixup = v2.Mixup(num_classes=NUM_CLASSES) +cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) + +for images, labels in dataloader: + print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }") + images, labels = cutmix_or_mixup(images, labels) + print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }") + + # + break +# %% +# +# Note how the labels were also transformed: we went from a batched label of +# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The +# transformed labels can still be passed as-is to a loss function like +# :func:`torch.nn.functional.cross_entropy`. +# +# As part of the collation function +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Passing the transforms after the DataLoader is the simplest way to use CutMix +# and MixUp, but one disadvantage is that it does not take advantage of the +# DataLoader multi-processing. For that, we can pass those transforms as part of +# the collation function (refer to the `PyTorch docs +# `_ to learn +# more about collation). + +from torch.utils.data import default_collate + +def collate_fn(batch): + return cutmix_or_mixup(*default_collate(batch)) + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader! + # + break + +# %% +# Non-standard input format +# ------------------------- +# +# So far we've used a typical sample structure where we pass ``(images, +# labels)`` as inputs. MixUp and CutMix will magically work by default with most +# common sample structures: tuples where the second parameter is a tensor label, +# or dict with a "label[s]" key. Look at the documentation of the +# ``labels_getter`` parameter for more details. +# +# If your samples have a different structure, you can still use CutMix and MixUp +# by passing a callable to the ``labels_getter`` parameter. For example: + +batch = { + "imgs": torch.rand(4, 3, 224, 224), + "target": { + "classes": torch.randint(0, NUM_CLASSES, size=(4,)), + "some_other_key": "this is going to be passed-through" + } +} +def labels_getter(batch): + return batch["target"]["classes"] + +out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) +print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }") diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 0ec3c5f01ee..069d5b7e5d2 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1922,7 +1922,7 @@ def test_supported_input_structure(self, T): dataset = self.DummyDataset(size=batch_size, num_classes=num_classes) - cutmix_mixup = T(alpha=0.5, num_classes=num_classes) + cutmix_mixup = T(num_classes=num_classes) dl = DataLoader(dataset, batch_size=batch_size) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 0841eebd983..cf5a4930639 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -141,7 +141,7 @@ def _transform( class _BaseMixupCutmix(Transform): - def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None: + def __init__(self, *, alpha: float = 1., num_classes: int, labels_getter="default") -> None: super().__init__() self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) @@ -210,7 +210,14 @@ class Mixup(_BaseMixupCutmix): Paper: `mixup: Beyond Empirical Risk Minimization `_. - See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. See + :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention.) In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed into a tensor of shape ``(batch_size, num_classes)``. @@ -253,7 +260,14 @@ class Cutmix(_BaseMixupCutmix): Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features `_. - See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. See + :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention.) In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed into a tensor of shape ``(batch_size, num_classes)``. From fa9790c2ebbbceb8070e8c75c9f9797da32d69a3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 31 Jul 2023 10:03:42 +0100 Subject: [PATCH 3/6] address comment --- torchvision/transforms/v2/_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index cf5a4930639..aad7793a7dc 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -143,7 +143,7 @@ def _transform( class _BaseMixupCutmix(Transform): def __init__(self, *, alpha: float = 1., num_classes: int, labels_getter="default") -> None: super().__init__() - self.alpha = alpha + self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self.num_classes = num_classes From e7e597787b7f2cef38dd179e5810ac0f377b97a7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 31 Jul 2023 21:12:55 +0100 Subject: [PATCH 4/6] put back -W --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index f5987993175..f462ff22303 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),) endif # You can set these variables from the command line. -SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS) +SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) SPHINXBUILD = sphinx-build SPHINXPROJ = torchvision SOURCEDIR = source From 4f54ac0413f4e37956316362bde6888439c44f37 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 31 Jul 2023 21:21:50 +0100 Subject: [PATCH 5/6] no chill --- gallery/plot_cutmix_mixup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index 1f7f91282aa..59529e5dc0a 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -98,7 +98,7 @@ # :func:`torch.nn.functional.cross_entropy`. # # As part of the collation function -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # Passing the transforms after the DataLoader is the simplest way to use CutMix # and MixUp, but one disadvantage is that it does not take advantage of the From afb15d81d4075c316b19c0fc76ab9e15a994cc40 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 31 Jul 2023 21:24:08 +0100 Subject: [PATCH 6/6] lint --- gallery/plot_cutmix_mixup.py | 5 +++++ torchvision/transforms/v2/_augment.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index 59529e5dc0a..d1c92a27812 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -109,9 +109,11 @@ from torch.utils.data import default_collate + def collate_fn(batch): return cutmix_or_mixup(*default_collate(batch)) + dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn) for images, labels in dataloader: @@ -140,8 +142,11 @@ def collate_fn(batch): "some_other_key": "this is going to be passed-through" } } + + def labels_getter(batch): return batch["target"]["classes"] + out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }") diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index aad7793a7dc..39bc2a2ce01 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -141,7 +141,7 @@ def _transform( class _BaseMixupCutmix(Transform): - def __init__(self, *, alpha: float = 1., num_classes: int, labels_getter="default") -> None: + def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None: super().__init__() self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))