Add tutorial for custom transforms and custom datapoints in gallery example #7795
@@ -0,0 +1,7 @@
| """ | ||
| ===================================== | ||
| How to write your own Datapoint class | ||
| ===================================== | ||
|
|
||
| TODO | ||
| """ |
@@ -0,0 +1,122 @@
| """ | ||
| ================================ | ||
| How to write your own transforms | ||
| ================================ | ||
|
|
||
| This guide explains how to write transforms that are compatible with the | ||
| torchvision transforms V2 API. | ||
| """ | ||
|
|
||
| # %% | ||
| import torch | ||
| import torchvision | ||
|
|
||
| # We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that | ||
| # some APIs may slightly change in the future | ||
| torchvision.disable_beta_transforms_warning() | ||
|
|
||
| from torchvision import datapoints | ||
| from torchvision.transforms import v2 | ||
|
|
||
|
|
||
# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
#     class MyCustomTransform(torch.nn.Module):
#         def forward(self, img, label):

Contributor: Do we want to use this as example although users cannot use such a transform with our datasets? Our classification datasets indeed return …

Member (Author): Honestly, I don't have an opinion, other than the fact that this probably doesn't matter and I don't feel like re-writing this. If you still think it's better to really illustrate this example the same way we have our datasets structured, I'll re-write.

Contributor: Leaning towards making this compatible with our datasets. OTOH, datasets aren't even mentioned here so it is a little less of a concern. I'm ok with leaving it as is.

#             # Do some transformations
#             return new_img, new_label
#
# .. note::
#
#     This means that if you have a custom transform that is already compatible
#     with the V1 transforms (those in ``torchvision.transforms``), it will
#     still work with the V2 transforms without any change!
#
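# For instance, here is a hypothetical V1-style transform (the class name and
# operation are made up for illustration) whose ``forward`` takes and returns a
# single image; it keeps working when composed with V2 transforms:
#
# .. code:: python
#
#     class MyV1StyleTransform(torch.nn.Module):
#         def forward(self, img):  # single-image signature, as in V1
#             return img - img.mean()  # some arbitrary image-only operation
#
#     out = v2.Compose([MyV1StyleTransform(), v2.RandomHorizontalFlip(p=1)])(torch.rand(3, 224, 224))
#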
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this

Contributor: Similar to above: do we want to use the …

        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label

transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
#     As you're manipulating datapoint classes in your code, make sure to
#     familiarize yourself with this section:
#     :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structure and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses) + some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
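
# %%
# For illustration only, here is a minimal sketch of that unpack / transform /
# repack idea. It is **not** torchvision's actual implementation: it relies on
# ``torch.utils._pytree``, a private utility whose API may change, and the class
# and helper names below are made up for this example.

from torch.utils._pytree import tree_flatten, tree_unflatten


class MyFlexibleTransform(torch.nn.Module):
    def forward(self, *inputs):
        # Accept either a single structured object or several positional arguments.
        sample = inputs if len(inputs) > 1 else inputs[0]

        # Unpack the arbitrarily-nested input into a flat list of leaves.
        flat_inputs, spec = tree_flatten(sample)

        transformed = []
        for leaf in flat_inputs:
            # Decide, based on the *class* of each entry, whether to transform it.
            if isinstance(leaf, (datapoints.Image, datapoints.BoundingBoxes)):
                leaf = self._transform(leaf)
            transformed.append(leaf)

        # Repack the (potentially transformed) leaves into the original structure.
        return tree_unflatten(transformed, spec)

    def _transform(self, inpt):
        # Dummy per-entry transformation: pass the input through unchanged.
        return inpt


# The nested dict from above goes in, and the same structure comes out.
out = MyFlexibleTransform()(structured_input)
assert isinstance(out, dict)
print(out["annotations"][0])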
@@ -3,13 +3,22 @@
Datapoints FAQ
==============

-The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
-showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
-to worry about: you do not need to understand the internals of datapoints to efficiently rely on
-``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets,
-transforms, or work directly with the datapoints.
+Datapoints are Tensor subclasses introduced together with
+``torchvision.transforms.v2``. This example showcases what these datapoints are
+and how they behave.
+
+.. warning::
+
+    **Intended Audience** Unless you're writing your own transforms or your own datapoints, you
+    probably do not need to read this guide. This is a fairly low-level topic
+    that most users will not need to worry about: you do not need to understand
+    the internals of datapoints to efficiently rely on
+    ``torchvision.transforms.v2``. It may however be useful for advanced users
+    trying to implement their own datasets, transforms, or work directly with
+    the datapoints.
"""

# %%
import PIL.Image

import torch

@@ -35,11 +44,20 @@
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()


# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
+# What can I do with a datapoint?
+# -------------------------------
+#
+# Datapoints look and feel just like regular tensors - they **are** tensors.
+# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
+# any ``torch.*`` operator will also work on datapoints. See
+# :ref:`datapoint_unwrapping_behaviour` for more details.

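# %%
# As a quick, runnable illustration of the point above (this snippet is not part
# of the diff, and the variable name is made up), ordinary tensor operations can
# be applied directly to a datapoint:

import torch
from torchvision import datapoints

img_dp = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

print(img_dp.shape)                 # regular tensor attribute
print(img_dp.float().mean())        # reductions work as usual
print(torch.flatten(img_dp).shape)  # torch.* operators work too
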
# %%
#
# What datapoints are supported?
# ------------------------------
#

@@ -79,10 +97,10 @@
# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:

-bounding_box = datapoints.BoundingBoxes(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
+bboxes = datapoints.BoundingBoxes(
+    [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
-print(bounding_box)
+print(bboxes)


# %%

@@ -105,8 +123,8 @@ class PennFudanDataset(torch.utils.data.Dataset):
    def __getitem__(self, item):
        ...

-        target["boxes"] = datapoints.BoundingBoxes(
-            boxes,
+        target["bboxes"] = datapoints.BoundingBoxes(
+            bboxes,

Comment on lines +156 to +157

Contributor: I would prefer we stick to the terms given in the tutorial this example is based on: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

Member (Author): I'm not sure, I think we want to be consistent with the names used in this tutorial. What's being used in another tutorial is lower in priority when it comes to consistency.

Contributor: The reason we have this example in the first place is because a user wanted to use the dataset from the tutorial and didn't get it working with v2: #6753 (comment). Plus, the …

Member (Author): OK, I'll move this section into another file, so let's revisit there. I think we should try to find a better way to illustrate all this anyway because right now that section isn't even executable.

            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )

@@ -147,7 +165,7 @@ def get_transform(train):
# %%
# .. note::
#
-#    If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+#    If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
#    the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
#    at least not wrapping the obsolete parts, can lead to a significant performance boost.
#

@@ -156,41 +174,63 @@ def get_transform(train):
# even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
# generated from the masks.
#
-# How do the datapoints behave inside a computation?
-# --------------------------------------------------
+# .. _datapoint_unwrapping_behaviour:
#
-# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
-# also works on datapoints.
-# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
-# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):
+# I had a Datapoint but now I have a Tensor. Help!
+# ------------------------------------------------
+#
+# For a lot of operations involving datapoints, we cannot safely infer whether
+# the result should retain the datapoint type, so we choose to return a plain
+# tensor instead of a datapoint (this might change, see note below):

assert isinstance(image, datapoints.Image)
assert isinstance(bboxes, datapoints.BoundingBoxes)

new_image = image + 0
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# If you're writing your own custom transforms or code involving datapoints, you
# can re-wrap the output into a datapoint by just calling their constructor, or
# by using the ``.wrap_like()`` class method:

new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# .. note::
#
#     You never need to re-wrap manually if you're using the built-in transforms
#     or their functional equivalents, because this logic is taken care of for
#     you.
#
# .. note::
#
#     This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#     have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#     https://github.com/pytorch/vision/issues/7319
#
-# There are two exceptions to this rule:
+# There are two exceptions to this "unwrapping" rule:
#
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
#    retain the datapoint type.
-# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use
-#    the flow style, the returned value will be unwrapped:
+# 2. Inplace operations on datapoints like ``.add_()`` preserve their type. However,
+#    the **returned** value of inplace operations will be unwrapped into a pure
+#    tensor:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

-assert isinstance(image, torch.Tensor)
+# image got transformed in-place and is still an Image datapoint, but new_image
+# is a Tensor. They share the same underlying data and they're equal, just
+# different classes.
+assert isinstance(image, datapoints.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
+assert (new_image == image).all()
+assert new_image.data_ptr() == image.data_ptr()

Contributor: We can do this to hammer the point home, but I feel one of them would be sufficient.
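
# %%
# As a small side check (not part of the diff above, just exercising the first
# exception stated in the list): ``.clone()`` and ``.to()`` keep the datapoint
# class.

assert isinstance(image.clone(), datapoints.Image)
assert isinstance(bboxes.to(torch.float64), datapoints.BoundingBoxes)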