Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad):

assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad
assert datapoint_deepcopied.is_leaf
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the deepcopy isn't a leaf anymore because it went through wrap_like(), so it's got an "ancestor".
I don't think is_leaf is part of the deepcopy contract anyway? I don't think we really need to enforce this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the deepcopy isn't a leaf anymore because it went through wrap_like(), so it's got an "ancestor". I don't think is_leaf is part of the deepcopy contract anyway? I don't think we really need to enforce this.

I don't think it is specified anywhere, so I'm ok with removing this check. Might be surprising to users though if they bank on this. Let's find out though 🤷

28 changes: 11 additions & 17 deletions torchvision/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,9 @@ def _to_tensor(
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)

_NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
# We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
# retains the type automatically
torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
# The ops in this set are those that should *preserve* the Datapoint type,
# i.e. they are exceptions to the "no wrapping" rule.
_NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}

@classmethod
def __torch_function__(
Expand Down Expand Up @@ -79,22 +74,21 @@ def __torch_function__(
with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict())

wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
# We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output)
return cls.wrap_like(args[0], output)

# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
# will retain the input type. Thus, we need to unwrap here.
if isinstance(output, cls):
return output.as_subclass(torch.Tensor)
if isinstance(output, cls):
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
Comment on lines +83 to +85
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also move most of the content of the DisableTorchFunctionSubclass out of it. The only part that matters is the call to func, the rest can be outside.

return output.as_subclass(torch.Tensor)

return output
return output

def _make_repr(self, **kwargs: Any) -> str:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
Expand Down