Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 135 additions & 160 deletions test/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@

import pytest
import torch
from common_utils import assert_equal
from common_utils import assert_equal, make_bounding_box, make_image, make_segmentation_mask, make_video
from PIL import Image

from torchvision import datapoints
from common_utils import (
make_bounding_box,
make_detection_mask,
make_image,
make_image_tensor,
make_segmentation_mask,
make_video,
)


@pytest.fixture(autouse=True)
def preserve_default_wrapping_behaviour():
def restore_tensor_return_type():
# This is a safety net: each test should already restore the default return type itself
# (at least at the time of writing...)
yield
datapoints.set_return_type("Tensor")

Expand Down Expand Up @@ -74,8 +68,9 @@ def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
assert datapoint.requires_grad is expected_requires_grad


def test_isinstance():
assert isinstance(datapoints.Image(torch.rand(3, 16, 16)), torch.Tensor)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_isinstance(make_input):
assert isinstance(make_input(), torch.Tensor)


def test_wrapping_no_copy():
Expand All @@ -85,65 +80,71 @@ def test_wrapping_no_copy():
assert image.data_ptr() == tensor.data_ptr()


def test_to_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
dp = make_input()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also call it `datapoint`, although that is very close to the name of the `datapoints` module. Or `inpt`. Or `input`. Or `foo`. Or `aejfbnakejfbk` (preferred). I don't care.


image_to = image.to(torch.float64)
dp_to = dp.to(torch.float64)

assert type(image_to) is datapoints.Image
assert image_to.dtype is torch.float64
assert type(dp_to) is type(dp)
assert dp_to.dtype is torch.float64


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(return_type):
def test_to_datapoint_reference(make_input, return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
image = datapoints.Image(tensor)
dp = make_input()

with datapoints.set_return_type(return_type):
tensor_to = tensor.to(image)
tensor_to = tensor.to(dp)

assert type(tensor_to) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is torch.float64
assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is dp.dtype


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))
def test_clone_wrapping(make_input, return_type):
dp = make_input()

with datapoints.set_return_type(return_type):
image_clone = image.clone()
dp_clone = dp.clone()

assert type(image_clone) is datapoints.Image
assert image_clone.data_ptr() != image.data_ptr()
assert type(dp_clone) is type(dp)
assert dp_clone.data_ptr() != dp.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))
def test_requires_grad__wrapping(make_input, return_type):
dp = make_input(dtype=torch.float)

assert not image.requires_grad
assert not dp.requires_grad

with datapoints.set_return_type(return_type):
image_requires_grad = image.requires_grad_(True)
dp_requires_grad = dp.requires_grad_(True)

assert type(image_requires_grad) is datapoints.Image
assert image.requires_grad
assert image_requires_grad.requires_grad
assert type(dp_requires_grad) is type(dp)
assert dp.requires_grad
assert dp_requires_grad.requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)
def test_detach_wrapping(make_input, return_type):
dp = make_input(dtype=torch.float).requires_grad_(True)

with datapoints.set_return_type(return_type):
image_detached = image.detach()
dp_detached = dp.detach()

assert type(image_detached) is datapoints.Image
assert type(dp_detached) is type(dp)


@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
# Largely the same as above, we additionally check that the metadata is preserved
format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

Expand All @@ -165,19 +166,22 @@ def test_force_subclass_with_metadata(return_type):
if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad
datapoints.set_return_type("tensor")


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))
def test_other_op_no_wrapping(make_input, return_type):
dp = make_input()

with datapoints.set_return_type(return_type):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2
output = dp * 2

assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
"op",
[
Expand All @@ -186,146 +190,117 @@ def test_other_op_no_wrapping(return_type):
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(op):
image = datapoints.Image(torch.rand(3, 16, 16))
def test_no_tensor_output_op_no_wrapping(make_input, op):
dp = make_input()

output = op(image)
output = op(dp)

assert type(output) is not datapoints.Image
assert type(output) is not type(dp)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16))
def test_inplace_op_no_wrapping(make_input, return_type):
dp = make_input()
original_type = type(dp)

with datapoints.set_return_type(return_type):
output = image.add_(0)
output = dp.add_(0)

assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert type(image) is datapoints.Image
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert type(dp) is original_type


def test_wrap_like():
image = datapoints.Image(torch.rand(3, 16, 16))
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_wrap_like(make_input):
dp = make_input()

# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2
output = dp * 2

image_new = datapoints.Image.wrap_like(image, output)
dp_new = type(dp).wrap_like(dp, output)

assert type(image_new) is datapoints.Image
assert image_new.data_ptr() == output.data_ptr()
assert type(dp_new) is type(dp)
assert dp_new.data_ptr() == output.data_ptr()


@pytest.mark.parametrize(
"datapoint",
[
datapoints.Image(torch.rand(3, 16, 16)),
datapoints.Video(torch.rand(2, 3, 16, 16)),
datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)),
datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(datapoint, requires_grad):
if requires_grad and not datapoint.dtype.is_floating_point:
return
def test_deepcopy(make_input, requires_grad):
dp = make_input(dtype=torch.float)

datapoint.requires_grad_(requires_grad)
dp.requires_grad_(requires_grad)

datapoint_deepcopied = deepcopy(datapoint)
dp_deepcopied = deepcopy(dp)

assert datapoint_deepcopied is not datapoint
assert datapoint_deepcopied.data_ptr() != datapoint.data_ptr()
assert_equal(datapoint_deepcopied, datapoint)
assert dp_deepcopied is not dp
assert dp_deepcopied.data_ptr() != dp.data_ptr()
assert_equal(dp_deepcopied, dp)

assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad
assert type(dp_deepcopied) is type(dp)
assert dp_deepcopied.requires_grad is requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_operations(return_type):
datapoints.set_return_type(return_type)
@pytest.mark.parametrize(
"op",
(
lambda dp: dp + torch.rand(*dp.shape),
lambda dp: torch.rand(*dp.shape) + dp,
lambda dp: dp * torch.rand(*dp.shape),
lambda dp: torch.rand(*dp.shape) * dp,
lambda dp: dp + 3,
lambda dp: 3 + dp,
lambda dp: dp + dp,
lambda dp: dp.sum(),
lambda dp: dp.reshape(-1),
lambda dp: dp.int(),
lambda dp: torch.stack([dp, dp]),
lambda dp: torch.chunk(dp, 2)[0],
lambda dp: torch.unbind(dp)[0],
),
)
def test_usual_operations(make_input, return_type, op):

dp = make_input()
with datapoints.set_return_type(return_type):
out = op(dp)
assert type(out) is (type(dp) if return_type == "datapoint" else torch.Tensor)
if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint":
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")


def test_subclasses():
    """Check that adding an image datapoint to a mask datapoint raises ``TypeError``."""
    image = make_image()
    segmentation_mask = make_segmentation_mask()

    # The expression is evaluated only for the exception it is expected to raise.
    with pytest.raises(TypeError, match="unsupported operand"):
        image + segmentation_mask


def test_set_return_type():
    """Walk through ``datapoints.set_return_type`` as a plain call and as a context manager.

    After every switch, the type of ``image + 3`` is checked to see whether the
    result is a plain ``torch.Tensor`` or a ``datapoints.Image``.
    """
    image = make_image()

    def result_type():
        # Any arithmetic op works here; only the type of its output matters.
        return type(image + 3)

    # Before any override in this test, the op yields a plain tensor.
    assert result_type() is torch.Tensor

    # Used as a context manager: the override applies inside the block only.
    with datapoints.set_return_type("datapoint"):
        assert result_type() is datapoints.Image
    assert result_type() is torch.Tensor

    # Used as a plain call: the setting stays in effect afterwards.
    datapoints.set_return_type("datapoint")
    assert result_type() is datapoints.Image

    with datapoints.set_return_type("tensor"):
        assert result_type() is torch.Tensor
        with datapoints.set_return_type("datapoint"):
            assert result_type() is datapoints.Image
            # A plain call made inside a context manager takes effect immediately...
            datapoints.set_return_type("tensor")
            assert result_type() is torch.Tensor
        assert result_type() is torch.Tensor
    # ...but exiting a context manager restores the return type that was active before
    # entering it, regardless of whether the "global" datapoints.set_return_type() was
    # called within the context manager.
    assert result_type() is datapoints.Image
Comment on lines +302 to +304
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, this part of the test is there to illustrate what happens, rather than to enforce something we want. It's not something we need to prevent either - really not worth it.


img = datapoints.Image(torch.rand(3, 10, 10))
t = torch.rand(3, 10, 10)
mask = datapoints.Mask(torch.rand(1, 10, 10))

for out in (
[
img + t,
t + img,
img * t,
t * img,
img + 3,
3 + img,
img * 3,
3 * img,
img + img,
img.sum(),
img.reshape(-1),
img.float(),
torch.stack([img, img]),
]
+ list(torch.chunk(img, 2))
+ list(torch.unbind(img))
):
assert type(out) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)

for out in (
[
mask + t,
t + mask,
mask * t,
t * mask,
mask + 3,
3 + mask,
mask * 3,
3 * mask,
mask + mask,
mask.sum(),
mask.reshape(-1),
mask.float(),
torch.stack([mask, mask]),
]
+ list(torch.chunk(mask, 2))
+ list(torch.unbind(mask))
):
assert type(out) is (datapoints.Mask if return_type == "datapoint" else torch.Tensor)

with pytest.raises(TypeError, match="unsupported operand type"):
img + mask

with pytest.raises(TypeError, match="unsupported operand type"):
img * mask

bboxes = datapoints.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
)
t = torch.rand(2, 4)

for out in (
[
bboxes + t,
t + bboxes,
bboxes * t,
t * bboxes,
bboxes + 3,
3 + bboxes,
bboxes * 3,
3 * bboxes,
bboxes + bboxes,
bboxes.sum(),
bboxes.reshape(-1),
bboxes.float(),
torch.stack([bboxes, bboxes]),
]
+ list(torch.chunk(bboxes, 2))
+ list(torch.unbind(bboxes))
):
if return_type == "Tensor":
assert type(out) is torch.Tensor
else:
assert isinstance(out, datapoints.BoundingBoxes)
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")
datapoints.set_return_type("tensor")