diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 2fc79a50612..e6d2321fc80 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -3,6 +3,25 @@ from torchvision.prototype import datapoints +@pytest.mark.parametrize( + ("data", "input_requires_grad", "expected_requires_grad"), + [ + ([0.0], None, False), + ([0.0], False, False), + ([0.0], True, True), + (torch.tensor([0.0], requires_grad=False), None, False), + (torch.tensor([0.0], requires_grad=False), False, False), + (torch.tensor([0.0], requires_grad=False), True, True), + (torch.tensor([0.0], requires_grad=True), None, True), + (torch.tensor([0.0], requires_grad=True), False, False), + (torch.tensor([0.0], requires_grad=True), True, True), + ], +) +def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): + datapoint = datapoints.Label(data, requires_grad=input_requires_grad) + assert datapoint.requires_grad is expected_requires_grad + + def test_isinstance(): assert isinstance( datapoints.Label([0, 1, 0], categories=["foo", "bar"]), diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py index 398770cbf6a..f3c9b6b345b 100644 --- a/torchvision/prototype/datapoints/_bounding_box.py +++ b/torchvision/prototype/datapoints/_bounding_box.py @@ -34,7 +34,7 @@ def __new__( spatial_size: Tuple[int, int], dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> BoundingBox: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index 848808d0250..fbd19ad86f1 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -23,8 +23,10 @@ def _to_tensor( data: Any, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> torch.Tensor: + if requires_grad is None: + requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a @@ -36,7 +38,7 @@ def __new__( data: Any, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> Datapoint: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return tensor.as_subclass(Datapoint) diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 56939bf14d9..4ffeb37d5eb 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -21,7 +21,7 @@ def __new__( *, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> Image: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if tensor.ndim < 2: diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 54915493390..0ee2eb9f8e1 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -27,7 +27,7 @@ def __new__( categories: Optional[Sequence[str]] = None, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> L: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor, categories=categories) diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py index ca4aba87d2e..834f990512b 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ b/torchvision/prototype/datapoints/_mask.py @@ -19,7 +19,7 @@ def __new__( *, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> Mask: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 6c24197a9ca..5cc8370cd7b 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -20,7 +20,7 @@ def __new__( *, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, + requires_grad: Optional[bool] = None, ) -> Video: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: