@@ -2280,3 +2280,61 @@ def resize_my_datapoint():
22802280 _register_kernel_internal (F .resize , MyDatapoint , datapoint_wrapper = False )(resize_my_datapoint )
22812281
22822282 assert _get_kernel (F .resize , MyDatapoint ) is resize_my_datapoint
2283+
2284+
2285+ class TestPermuteChannels :
2286+ _DEFAULT_PERMUTATION = [2 , 0 , 1 ]
2287+
2288+ @pytest .mark .parametrize (
2289+ ("kernel" , "make_input" ),
2290+ [
2291+ (F .permute_channels_image_tensor , make_image_tensor ),
2292+ # FIXME
2293+ # check_kernel does not support PIL kernel, but it should
2294+ (F .permute_channels_image_tensor , make_image ),
2295+ (F .permute_channels_video , make_video ),
2296+ ],
2297+ )
2298+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2299+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2300+ def test_kernel (self , kernel , make_input , dtype , device ):
2301+ check_kernel (kernel , make_input (dtype = dtype , device = device ), permutation = self ._DEFAULT_PERMUTATION )
2302+
2303+ @pytest .mark .parametrize (
2304+ ("kernel" , "make_input" ),
2305+ [
2306+ (F .permute_channels_image_tensor , make_image_tensor ),
2307+ (F .permute_channels_image_pil , make_image_pil ),
2308+ (F .permute_channels_image_tensor , make_image ),
2309+ (F .permute_channels_video , make_video ),
2310+ ],
2311+ )
2312+ def test_dispatcher (self , kernel , make_input ):
2313+ check_dispatcher (F .permute_channels , kernel , make_input (), permutation = self ._DEFAULT_PERMUTATION )
2314+
2315+ @pytest .mark .parametrize (
2316+ ("kernel" , "input_type" ),
2317+ [
2318+ (F .permute_channels_image_tensor , torch .Tensor ),
2319+ (F .permute_channels_image_pil , PIL .Image .Image ),
2320+ (F .permute_channels_image_tensor , datapoints .Image ),
2321+ (F .permute_channels_video , datapoints .Video ),
2322+ ],
2323+ )
2324+ def test_dispatcher_signature (self , kernel , input_type ):
2325+ check_dispatcher_kernel_signature_match (F .permute_channels , kernel = kernel , input_type = input_type )
2326+
2327+ def reference_image_correctness (self , image , permutation ):
2328+ channel_images = image .split (1 , dim = - 3 )
2329+ permuted_channel_images = [channel_images [channel_idx ] for channel_idx in permutation ]
2330+ return datapoints .Image (torch .concat (permuted_channel_images , dim = - 3 ))
2331+
2332+ @pytest .mark .parametrize ("permutation" , [[2 , 0 , 1 ], [1 , 2 , 0 ], [2 , 0 , 1 ], [0 , 1 , 2 ]])
2333+ @pytest .mark .parametrize ("batch_dims" , [(), (2 ,), (2 , 1 )])
2334+ def test_image_correctness (self , permutation , batch_dims ):
2335+ image = make_image (batch_dims = batch_dims )
2336+
2337+ actual = F .permute_channels (image , permutation = permutation )
2338+ expected = self .reference_image_correctness (image , permutation = permutation )
2339+
2340+ torch .testing .assert_close (actual , expected )
0 commit comments