@@ -4514,7 +4514,7 @@ class TestFiveTenCrop:
 
     @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
-    @pytest.mark.parametrize("kernel", [F.five_crop_image, F.ten_crop])
+    @pytest.mark.parametrize("kernel", [F.five_crop_image, F.ten_crop_image])
     def test_kernel_image(self, dtype, device, kernel):
         check_kernel(
             kernel,
@@ -4527,9 +4527,9 @@ def test_kernel_image(self, dtype, device, kernel):
     def test_kernel_video(self, kernel):
         check_kernel(kernel, make_video(self.INPUT_SIZE), size=self.OUTPUT_SIZE, check_batched_vs_unbatched=False)
 
-    def _five_ten_crop_functional_wrapper(self, fn):
+    def _functional_wrapper(self, fn):
         # This wrapper is needed to make five_crop / ten_crop compatible with check_functional, since that requires a
-        # single rather than a sequence.
+        # single output rather than a sequence.
         @functools.wraps(fn)
         def wrapper(*args, **kwargs):
             outputs = fn(*args, **kwargs)
@@ -4544,7 +4544,7 @@ def wrapper(*args, **kwargs):
     @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
     def test_functional(self, make_input, functional):
         check_functional(
-            self._five_ten_crop_functional_wrapper(functional),
+            self._functional_wrapper(functional),
             make_input(self.INPUT_SIZE),
             size=self.OUTPUT_SIZE,
             check_scripted_smoke=False,
@@ -4566,13 +4566,13 @@ def test_functional(self, make_input, functional):
     def test_functional_signature(self, functional, kernel, input_type):
         check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
 
-    class _FiveTenCropTransformWrapper(nn.Module):
+    class _TransformWrapper(nn.Module):
         # This wrapper is needed to make FiveCrop / TenCrop compatible with check_transform, since that requires a
         # single output rather than a sequence.
         _v1_transform_cls = None
 
         def _extract_params_for_v1_transform(self):
-            return dict(five_ten_crop_transform=self.five_ten_crop_transform.__prepare_scriptable__())
+            return dict(five_ten_crop_transform=self.five_ten_crop_transform)
 
         def __init__(self, five_ten_crop_transform):
             super().__init__()
@@ -4589,9 +4589,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     )
     @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
     def test_transform(self, make_input, transform_cls):
-        check_transform(
-            self._FiveTenCropTransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE)
-        )
+        check_transform(self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE))
 
     @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
     @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
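
The comments in the diff describe the wrapper trick: five_crop / ten_crop return a sequence of crops, while check_functional expects a single output per call, so the (renamed) _functional_wrapper collapses the sequence. Below is a minimal, self-contained sketch of that pattern; fake_five_crop and the choice to forward the first crop are assumptions for illustration, since the hunk above elides the wrapper's return.

```python
import functools

import torch


def fake_five_crop(image, size):
    # Hypothetical stand-in for F.five_crop: returns a sequence of crops.
    h, w = size
    return tuple(image[..., :h, :w] for _ in range(5))


def functional_wrapper(fn):
    # Collapse the sequence output to a single tensor so that a helper which
    # expects one output per call (like check_functional) can consume it.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        outputs = fn(*args, **kwargs)
        return outputs[0]  # assumption: forwarding the first crop suffices

    return wrapper


image = torch.rand(3, 32, 32)
single = functional_wrapper(fake_five_crop)(image, size=(8, 8))
assert single.shape == (3, 8, 8)
```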
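The transform-side wrapper works the same way: an nn.Module owns a FiveCrop/TenCrop-style transform and reduces its sequence output so check_transform sees a single tensor. A companion sketch in the same spirit as _TransformWrapper, with FakeFiveCrop and the first-crop reduction as hypothetical stand-ins (it omits the _v1_transform_cls machinery visible in the diff):

```python
import torch
from torch import nn


class FakeFiveCrop(nn.Module):
    # Hypothetical stand-in for transforms.FiveCrop.
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, image):
        h, w = self.size
        return tuple(image[..., :h, :w] for _ in range(5))


class TransformWrapper(nn.Module):
    def __init__(self, transform):
        super().__init__()
        self.transform = transform

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        outputs = self.transform(input)
        return outputs[0]  # assumption: a single crop satisfies the checker


wrapped = TransformWrapper(FakeFiveCrop(size=(8, 8)))
assert wrapped(torch.rand(3, 32, 32)).shape == (3, 8, 8)
```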