Skip to content

Commit 58d6ecb

Browse files
committed
fix five / ten crop
1 parent 64737ba commit 58d6ecb

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4514,7 +4514,7 @@ class TestFiveTenCrop:
45144514

45154515
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
45164516
@pytest.mark.parametrize("device", cpu_and_cuda())
4517-
@pytest.mark.parametrize("kernel", [F.five_crop_image, F.ten_crop])
4517+
@pytest.mark.parametrize("kernel", [F.five_crop_image, F.ten_crop_image])
45184518
def test_kernel_image(self, dtype, device, kernel):
45194519
check_kernel(
45204520
kernel,
@@ -4527,9 +4527,9 @@ def test_kernel_image(self, dtype, device, kernel):
45274527
def test_kernel_video(self, kernel):
45284528
check_kernel(kernel, make_video(self.INPUT_SIZE), size=self.OUTPUT_SIZE, check_batched_vs_unbatched=False)
45294529

4530-
def _five_ten_crop_functional_wrapper(self, fn):
4530+
def _functional_wrapper(self, fn):
45314531
# This wrapper is needed to make five_crop / ten_crop compatible with check_functional, since that requires a
4532-
# single rather than a sequence.
4532+
# single output rather than a sequence.
45334533
@functools.wraps(fn)
45344534
def wrapper(*args, **kwargs):
45354535
outputs = fn(*args, **kwargs)
@@ -4544,7 +4544,7 @@ def wrapper(*args, **kwargs):
45444544
@pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
45454545
def test_functional(self, make_input, functional):
45464546
check_functional(
4547-
self._five_ten_crop_functional_wrapper(functional),
4547+
self._functional_wrapper(functional),
45484548
make_input(self.INPUT_SIZE),
45494549
size=self.OUTPUT_SIZE,
45504550
check_scripted_smoke=False,
@@ -4566,13 +4566,13 @@ def test_functional(self, make_input, functional):
45664566
def test_functional_signature(self, functional, kernel, input_type):
45674567
check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
45684568

4569-
class _FiveTenCropTransformWrapper(nn.Module):
4569+
class _TransformWrapper(nn.Module):
45704570
# This wrapper is needed to make FiveCrop / TenCrop compatible with check_transform, since that requires a
45714571
# single output rather than a sequence.
45724572
_v1_transform_cls = None
45734573

45744574
def _extract_params_for_v1_transform(self):
4575-
return dict(five_ten_crop_transform=self.five_ten_crop_transform.__prepare_scriptable__())
4575+
return dict(five_ten_crop_transform=self.five_ten_crop_transform)
45764576

45774577
def __init__(self, five_ten_crop_transform):
45784578
super().__init__()
@@ -4589,9 +4589,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
45894589
)
45904590
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
45914591
def test_transform(self, make_input, transform_cls):
4592-
check_transform(
4593-
self._FiveTenCropTransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE)
4594-
)
4592+
check_transform(self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE))
45954593

45964594
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
45974595
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])

0 commit comments

Comments
 (0)