diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index 0a62a991a75..b0a48d75d6a 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -3,7 +3,7 @@ How to write your own Datapoint class ===================================== -This guide is intended for downstream library maintainers. We explain how to +This guide is intended for advanced users and downstream library maintainers. We explain how to write your own datapoint class, and how to make it compatible with the built-in Torchvision v2 transforms. Before continuing, make sure you have read :ref:`sphx_glr_auto_examples_plot_datapoints.py`. @@ -68,10 +68,6 @@ def hflip_my_datapoint(my_dp, *args, **kwargs): # could also have used the functional *itself*, i.e. # ``@register_kernel(functional=F.hflip, ...)``. # -# The functionals that you can be hooked into are the ones in -# ``torchvision.transforms.v2.functional`` and they are documented in -# :ref:`functional_transforms`. -# # Now that we have registered our kernel, we can call the functional API on a # ``MyDatapoint`` instance: diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index d87575cdb8e..5bbf6c200af 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -48,26 +48,22 @@ # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # +# :mod:`torchvision.datapoints` supports four types of datapoints: +# +# * :class:`~torchvision.datapoints.Image` +# * :class:`~torchvision.datapoints.Video` +# * :class:`~torchvision.datapoints.BoundingBoxes` +# * :class:`~torchvision.datapoints.Mask` +# # What can I do with a datapoint? # ------------------------------- # # Datapoints look and feel just like regular tensors - they **are** tensors. # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or -# any ``torch.*`` operator will also works on datapoints. 
See +# any ``torch.*`` operator will also work on datapoints. See # :ref:`datapoint_unwrapping_behaviour` for a few gotchas. # %% -# -# What datapoints are supported? -# ------------------------------ -# -# So far :mod:`torchvision.datapoints` supports four types of datapoints: -# -# * :class:`~torchvision.datapoints.Image` -# * :class:`~torchvision.datapoints.Video` -# * :class:`~torchvision.datapoints.BoundingBoxes` -# * :class:`~torchvision.datapoints.Mask` -# # .. _datapoint_creation: # # How do I construct a datapoint? @@ -209,9 +205,8 @@ def get_transform(train): # I had a Datapoint but now I have a Tensor. Help! # ------------------------------------------------ # -# For a lot of operations involving datapoints, we cannot safely infer whether -# the result should retain the datapoint type, so we choose to return a plain -# tensor instead of a datapoint (this might change, see note below): +# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects +# will return a pure Tensor: assert isinstance(bboxes, datapoints.BoundingBoxes) @@ -219,32 +214,69 @@ def get_transform(train): # Shift bboxes by 3 pixels in both H and W new_bboxes = bboxes + 3 -assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes) +assert isinstance(new_bboxes, torch.Tensor) +assert not isinstance(new_bboxes, datapoints.BoundingBoxes) + +# %% +# .. note:: +# +# This behavior only affects native ``torch`` operations. If you are using +# the built-in ``torchvision`` transforms or functionals, you will always get +# as output the same type that you passed as input (pure ``Tensor`` or +# ``Datapoint``). # %% -# If you're writing your own custom transforms or code involving datapoints, you -# can re-wrap the output into a datapoint by just calling their constructor, or -# by using the ``.wrap_like()`` class method: +# But I want a Datapoint back! 
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# You can re-wrap a pure tensor into a datapoint by just calling the datapoint +# constructor, or by using the ``.wrap_like()`` class method (see more details +# above in :ref:`datapoint_creation`): new_bboxes = bboxes + 3 new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) assert isinstance(new_bboxes, datapoints.BoundingBoxes) # %% -# See more details above in :ref:`datapoint_creation`. +# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type` +# as a global config setting for the whole program, or as a context manager: + +with datapoints.set_return_type("datapoint"): + new_bboxes = bboxes + 3 +assert isinstance(new_bboxes, datapoints.BoundingBoxes) + +# %% +# Why is this happening? +# ^^^^^^^^^^^^^^^^^^^^^^ # -# .. note:: +# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint` +# classes are Tensor subclasses, so any operation involving a +# :class:`~torchvision.datapoints.Datapoint` object will go through the +# `__torch_function__ +# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_ +# protocol. This induces a small overhead, which we want to avoid when possible. +# This doesn't matter for built-in ``torchvision`` transforms because we can +# avoid the overhead there, but it could be a problem in your model's +# ``forward``. # -# You never need to re-wrap manually if you're using the built-in transforms -# or their functional equivalents: this is automatically taken care of for -# you. +# **The alternative isn't much better anyway.** For every operation where +# preserving the :class:`~torchvision.datapoints.Datapoint` type makes +# sense, there are just as many operations where returning a pure Tensor is +# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
+# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all +# the way, even the model's logits or the output of the loss function would end up +# being of type :class:`~torchvision.datapoints.Image`, and surely that's not +# desirable. # # .. note:: # -# This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you +# This behaviour is something we're actively seeking feedback on. If you find this surprising or if you # have any suggestions on how to better support your use-cases, please reach out to us via this issue: # https://github.com/pytorch/vision/issues/7319 # +# Exceptions +# ^^^^^^^^^^ +# # There are a few exceptions to this "unwrapping" rule: # # 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, diff --git a/test/test_datapoints.py b/test/test_datapoints.py index e38fcaf1d04..1042587e396 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -101,6 +101,7 @@ def test_to_datapoint_reference(make_input, return_type): assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor) assert tensor_to.dtype is dp.dtype + assert type(tensor) is torch.Tensor @pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video]) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index c4b5ee48d68..613a1fb8b25 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -66,19 +66,12 @@ def __torch_function__( ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the ``args`` and ``kwargs`` of the original call. - The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint` - use case, this has two downsides: + Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type + of the output.
In our case, we want to return pure tensors instead (with a few exceptions). Refer to the + "Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet). - 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e. - ``return cls(func(*args, **kwargs))``, will fail for them. - 2. For most operations, there is no way of knowing if the input type is still valid for the output. - - For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are - listed in _FORCE_TORCHFUNCTION_SUBCLASS + Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out. """ - # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we - # need to reimplement the functionality. - if not all(issubclass(cls, t) for t in types): return NotImplemented @@ -89,12 +82,13 @@ def __torch_function__( must_return_subclass = _must_return_subclass() if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)): - # We also require the primary operand, i.e. `args[0]`, to be - # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will - # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, - # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with - # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would - # be wrapped into a `datapoints.Image`. + # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails + # in test_to_datapoint_reference(). + # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in + # the computation by walking the MRO upwards. 
For example, + # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with + # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would + # be wrapped into an `Image`. return cls._wrap_output(output, args, kwargs) if not must_return_subclass and isinstance(output, cls): diff --git a/torchvision/datapoints/_torch_function_helpers.py b/torchvision/datapoints/_torch_function_helpers.py index 68674eb024d..6ab4f415802 100644 --- a/torchvision/datapoints/_torch_function_helpers.py +++ b/torchvision/datapoints/_torch_function_helpers.py @@ -18,12 +18,18 @@ def __exit__(self, *args): def set_return_type(return_type: str): """Set the return type of torch operations on datapoints. + This only affects the behaviour of torch operations. It has no effect on + ``torchvision`` transforms or functionals, which will always return as + output the same type that was passed as input. + Can be used as a global flag for the entire program: .. code:: python - set_return_type("datapoints") img = datapoints.Image(torch.rand(3, 5, 5)) + img + 2 # This is a pure Tensor (default behaviour) + + set_return_type("datapoint") img + 2 # This is an Image or as a context manager to restrict the scope: @@ -31,6 +37,7 @@ def set_return_type(return_type: str): ..
code:: python img = datapoints.Image(torch.rand(3, 5, 5)) + img + 2 # This is a pure Tensor with set_return_type("datapoints"): img + 2 # This is an Image img + 2 # This is a pure Tensor diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 95145beee4d..1f5c6f5eea0 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -19,8 +19,15 @@ def is_simple_tensor(inpt: Any) -> bool: def _kernel_datapoint_wrapper(kernel): @functools.wraps(kernel) def wrapper(inpt, *args, **kwargs): - # We always pass datapoints as pure tensors to the kernels to avoid going through the - # Tensor.__torch_function__ logic, which is costly. + # If you're wondering whether we could / should get rid of this wrapper, + # the answer is no: we want to pass pure Tensors to avoid the overhead + # of the __torch_function__ machinery. Note that this is always valid, + # regardless of whether we override __torch_function__ in our base class + # or not. + # Also, even if we didn't call `as_subclass` here, we would still need + # this wrapper to call wrap_like(), because the Datapoint type would be + # lost after the first operation due to our own __torch_function__ + # logic. output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) return type(inpt).wrap_like(inpt, output)