diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 45668fda1ca..8a858bf58c2 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2181,3 +2181,36 @@ def test_unsupported_types(self, dispatcher, make_input): with pytest.raises(TypeError, match=re.escape(str(type(input)))): dispatcher(input) + + +class TestRegisterKernel: + @pytest.mark.parametrize("dispatcher", (F.resize, "resize")) + def test_register_kernel(self, dispatcher): + class CustomDatapoint(datapoints.Datapoint): + pass + + kernel_was_called = False + + @F.register_kernel(dispatcher, CustomDatapoint) + def new_resize(dp, *args, **kwargs): + nonlocal kernel_was_called + kernel_was_called = True + return dp + + t = transforms.Resize(size=(224, 224), antialias=True) + + my_dp = CustomDatapoint(torch.rand(3, 10, 10)) + out = t(my_dp) + assert out is my_dp + assert kernel_was_called + + # Sanity check to make sure we didn't override the kernel of other types + t(torch.rand(3, 10, 10)).shape == (3, 224, 224) + t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224) + + def test_bad_disaptcher_name(self): + class CustomDatapoint(datapoints.Datapoint): + pass + + with pytest.raises(ValueError, match="Could not find dispatcher with name"): + F.register_kernel("bad_name", CustomDatapoint) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 63e029d6c77..1eaa54102a4 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -37,7 +37,18 @@ def decorator(kernel): return decorator +def _name_to_dispatcher(name): + import torchvision.transforms.v2.functional # noqa + + try: + return getattr(torchvision.transforms.v2.functional, name) + except AttributeError: + raise ValueError(f"Could not find dispatcher with name '{name}'.") from None + + def register_kernel(dispatcher, datapoint_cls): + if isinstance(dispatcher, str): + dispatcher = _name_to_dispatcher(name=dispatcher) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)