
Conversation

@pmeier (Contributor) commented Jan 26, 2023

Previously, constructing a new datapoint from an existing tensor would simply overwrite the requires_grad flag:

return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

This is the correct behavior in all but one case: if we pass a non-leaf tensor that already has requires_grad=True and don't specify requires_grad=True in the call, the creation fails:

import torch
from torchvision.prototype import datapoints

t1 = torch.rand(3, 2, 2).requires_grad_(True)
assert t1.requires_grad

# Works: requires_grad=True is passed explicitly.
i1 = datapoints.Image(t1, requires_grad=True)
assert i1.data_ptr() == t1.data_ptr()
assert i1.requires_grad

# t2 is a non-leaf tensor that requires grad.
t2 = t1.clone()
assert t2.requires_grad

# Fails: the default requires_grad=False is forced onto the non-leaf t2.
i2 = datapoints.Image(t2)
RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

This PR implements a passthrough for requires_grad in the default case if the input is a tensor: when the caller does not specify requires_grad, the input tensor's existing flag is kept instead of being overwritten. For non-tensor inputs, the behavior stays exactly the same.
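A minimal sketch of the passthrough (the helper name _to_tensor and the exact signature are assumptions for illustration, not necessarily the merged code):

import torch

def _to_tensor(data, *, dtype=None, device=None, requires_grad=None):
    # requires_grad=None means "not specified by the caller".
    if requires_grad is None:
        # Passthrough: inherit the flag from tensor inputs; default to False otherwise.
        requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
    return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

With this, the non-leaf t2 above ends up calling requires_grad_(True), which is allowed; only forcing the flag to False on a non-leaf raises the error shown above.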

cc @bjuncek

@NicolasHug (Member) left a comment

Thanks for the fix, Philip.

This makes me wonder whether we actually want/need a requires_grad parameter in the datapoint constructors? (I guess the same question goes for the rest of the parameters, e.g. dtype and device.)

@pmeier (Contributor, Author) commented Jan 27, 2023

The reason we have them there is that our constructor basically acts as a creation function. Meaning you can do datapoints.Label(0, dtype=..., ...) without going the long way like datapoints.Label(torch.tensor(0, dtype=..., ...)).
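Concretely, that difference looks like this (torch.int64 is just an illustrative dtype choice):

import torch
from torchvision.prototype import datapoints

# The short way: the datapoint constructor doubles as a creation function.
label = datapoints.Label(0, dtype=torch.int64)

# The long way: create the tensor first, then wrap it.
label = datapoints.Label(torch.tensor(0, dtype=torch.int64))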

This is different from the (undocumented) constructor of a plain tensor. It takes the input data and handles it in a way that differs from torch.tensor and the rest of PyTorch (sequence inputs behave the same, but if you pass a scalar int, the output is basically torch.empty(size)).
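For example, a scalar int is interpreted as a size by the plain constructor, but as data by torch.tensor:

import torch

torch.tensor(5)  # 0-dim tensor holding the value 5
torch.Tensor(5)  # uninitialized 1-dim tensor of shape (5,), like torch.empty(5)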

Our thinking was that wrapping non-tensor data into our datapoints will be really common when coming from arbitrary datasets, so we wanted to make this use case more convenient. Although we behave differently from the plain tensor constructor, we were ok with that, since it is undocumented and I haven't seen it used in the real world anywhere.

@NicolasHug (Member) left a comment

Thanks for the details. I was under the wrong impression that users would mostly just create datapoints out of tensors, but I realize now that Image(some_list) is a perfectly valid thing to do.

@pmeier merged commit 2bc8a14 into pytorch:main Jan 27, 2023
@pmeier deleted the requires-grad-passthrough branch January 27, 2023 10:25
@github-actions (bot) commented
Hey @pmeier!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Feb 8, 2023
Reviewed By: vmoens

Differential Revision: D43116110

fbshipit-source-id: 1d58bff6d14b79849fbe76a8a2b017e5461315c1