16 changes: 0 additions & 16 deletions test/test_transforms_v2.py
@@ -328,22 +328,6 @@ def test_auto_augment(self, transform, input):
def test_normalize(self, transform, input):
transform(input)

@parametrize(
[
(
transforms.RandomResizedCrop([16, 16], antialias=True),
itertools.chain(
make_images(extra_dims=[(4,)]),
make_vanilla_tensor_images(),
make_pil_images(),
make_videos(extra_dims=[()]),
),
)
]
)
def test_random_resized_crop(self, transform, input):
transform(input)


@pytest.mark.parametrize(
"flat_inputs",
25 changes: 0 additions & 25 deletions test/test_transforms_v2_consistency.py
@@ -252,30 +252,6 @@ def __init__(
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs(16),
ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.ColorJitter,
legacy_transforms.ColorJitter,
@@ -535,7 +511,6 @@ def test_call_consistency(config, args_kwargs):
id=transform_cls.__name__,
)
for transform_cls, get_params_args_kwargs in [
(v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
83 changes: 0 additions & 83 deletions test/test_transforms_v2_functional.py
@@ -1,5 +1,4 @@
import inspect
import math
import os
import re

@@ -526,88 +525,6 @@ def test_tv_tensor_explicit_metadata(self, metadata):
# `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
Contributor Author: Driveby. This was no longer used anywhere.

rot = math.radians(angle_)
cx, cy = center_
tx, ty = translate_
sx, sy = [math.radians(sh_) for sh_ in shear_]

c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
c_matrix_inv = np.linalg.inv(c_matrix)
rs_matrix = np.array(
[
[scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
[scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
[0, 0, 1],
]
)
shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
return true_matrix
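
For reference, the removed helper built the usual center-aware affine composition; in the notation of the code above (C = move-to-center translation, T = output translation, RS = rotation-scale, Sh_x / Sh_y = shears):

true_matrix = T @ C @ (RS @ Sh_y @ Sh_x) @ C_inv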


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
Contributor Author: Driveby. We forgot to remove this when porting the tests for F.vertical_flip.

mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
mask[:, 0, :] = 1

out_mask = F.vertical_flip_mask(mask)

expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
expected_mask[:, -1, :] = 1
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"format",
[tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
"top, left, height, width, size",
[
[0, 0, 30, 30, (60, 60)],
[-5, 5, 35, 45, (32, 34)],
],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
# bbox should be xyxy
bbox[0] = (bbox[0] - left_) * size_[1] / width_
bbox[1] = (bbox[1] - top_) * size_[0] / height_
bbox[2] = (bbox[2] - left_) * size_[1] / width_
bbox[3] = (bbox[3] - top_) * size_[0] / height_
return bbox

format = tv_tensors.BoundingBoxFormat.XYXY
canvas_size = (100, 100)
in_boxes = [
[10.0, 10.0, 20.0, 20.0],
[5.0, 10.0, 15.0, 20.0],
]
expected_bboxes = []
for in_box in in_boxes:
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device)

in_boxes = tv_tensors.BoundingBoxes(
in_boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
)
if format != tv_tensors.BoundingBoxFormat.XYXY:
in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)

output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

if format != tv_tensors.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)

torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_canvas_size, size)


def _parse_padding(padding):
if isinstance(padding, int):
return [padding] * 4
133 changes: 133 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -3110,3 +3110,136 @@ def test_errors(self):
F.convert_bounding_box_format(
input_tv_tensor, old_format=input_tv_tensor.format, new_format=input_tv_tensor.format
)


class TestResizedCrop:
INPUT_SIZE = (17, 11)
CROP_KWARGS = dict(top=2, left=2, height=5, width=7)
OUTPUT_SIZE = (19, 32)

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.resized_crop_image, make_image),
(F.resized_crop_bounding_boxes, make_bounding_boxes),
(F.resized_crop_mask, make_segmentation_mask),
(F.resized_crop_mask, make_detection_mask),
(F.resized_crop_video, make_video),
],
)
def test_kernel(self, kernel, make_input):
input = make_input(self.INPUT_SIZE)
if isinstance(input, tv_tensors.BoundingBoxes):
extra_kwargs = dict(format=input.format)
elif isinstance(input, tv_tensors.Mask):
extra_kwargs = dict()
else:
extra_kwargs = dict(antialias=True)

check_kernel(kernel, input, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, **extra_kwargs)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
def test_functional(self, make_input):
check_functional(
F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True
)

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.resized_crop_image, torch.Tensor),
(F._resized_crop_image_pil, PIL.Image.Image),
(F.resized_crop_image, tv_tensors.Image),
(F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resized_crop_mask, tv_tensors.Mask),
(F.resized_crop_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)

@param_value_parametrization(
scale=[(0.1, 0.2), [0.0, 1.0]],
ratio=[(0.3, 0.7), [0.1, 5.0]],
)
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
def test_transform(self, param, value, make_input):
check_transform(
transforms.RandomResizedCrop(size=self.OUTPUT_SIZE, **{param: value}, antialias=True),
make_input(self.INPUT_SIZE),
check_v1_compatibility=dict(rtol=0, atol=1),
)

# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
def test_functional_image_correctness(self, interpolation):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8)

actual = F.resized_crop(
image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
)
expected = F.to_image(
F.resized_crop(
F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
)
)

torch.testing.assert_close(actual, expected, atol=1, rtol=0)

def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
new_height, new_width = size

crop_affine_matrix = np.array(
[
[1, 0, -left],
[0, 1, -top],
[0, 0, 1],
],
)
resize_affine_matrix = np.array(
[
[new_width / width, 0, 0],
[0, new_height / height, 0],
[0, 0, 1],
],
)
affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

return reference_affine_bounding_boxes_helper(
bounding_boxes,
affine_matrix=affine_matrix,
new_canvas_size=size,
)
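
As a sanity check on the reference helper above, a small standalone sketch (not part of the diff; the values mirror CROP_KWARGS and OUTPUT_SIZE) confirming that the composed matrix maps the crop's top-left corner to the output origin:

import numpy as np

top, left, height, width = 2, 2, 5, 7   # CROP_KWARGS
new_height, new_width = 19, 32          # OUTPUT_SIZE

crop = np.array([[1, 0, -left], [0, 1, -top], [0, 0, 1]])
resize = np.array([[new_width / width, 0, 0], [0, new_height / height, 0], [0, 0, 1]])
affine = (resize @ crop)[:2, :]

# The crop's top-left corner (x=left, y=top) lands at (0, 0) in the output.
assert np.allclose(affine @ np.array([left, top, 1.0]), [0.0, 0.0])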

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
def test_functional_bounding_boxes_correctness(self, format):
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)

actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
expected = self._reference_resized_crop_bounding_boxes(
bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
)

assert_equal(actual, expected)
assert_equal(F.get_size(actual), F.get_size(expected))

def test_transform_errors_warnings(self):
with pytest.raises(ValueError, match="provide only two dimensions"):
transforms.RandomResizedCrop(size=(1, 2, 3))

with pytest.raises(TypeError, match="Scale should be a sequence"):
transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=123)

with pytest.raises(TypeError, match="Ratio should be a sequence"):
transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=123)

for param in ["scale", "ratio"]:
with pytest.warns(match="Scale and ratio should be of kind"):
transforms.RandomResizedCrop(size=self.INPUT_SIZE, **{param: [1, 0]})
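
To make the new fixtures concrete, a minimal standalone sketch (not part of the diff) of the transform under test, assuming a torchvision build that ships the v2 API:

import torch
from torchvision.transforms import v2 as transforms

# Random uint8 image shaped like INPUT_SIZE, cropped and resized to OUTPUT_SIZE.
img = torch.randint(0, 256, (3, 17, 11), dtype=torch.uint8)
t = transforms.RandomResizedCrop(size=(19, 32), scale=(0.1, 0.2), antialias=True)
out = t(img)
assert out.shape[-2:] == torch.Size([19, 32])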
10 changes: 0 additions & 10 deletions test/transforms_v2_dispatcher_infos.py
@@ -111,16 +111,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):


DISPATCHER_INFOS = [
DispatcherInfo(
F.resized_crop,
kernels={
tv_tensors.Image: F.resized_crop_image,
tv_tensors.Video: F.resized_crop_video,
tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
tv_tensors.Mask: F.resized_crop_mask,
},
pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
),
DispatcherInfo(
F.pad,
kernels={