Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2181,3 +2181,36 @@ def test_unsupported_types(self, dispatcher, make_input):

with pytest.raises(TypeError, match=re.escape(str(type(input)))):
dispatcher(input)


class TestRegisterKernel:
@pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
def test_register_kernel(self, dispatcher):
class CustomDatapoint(datapoints.Datapoint):
pass

kernel_was_called = False

@F.register_kernel(dispatcher, CustomDatapoint)
def new_resize(dp, *args, **kwargs):
nonlocal kernel_was_called
kernel_was_called = True
return dp

t = transforms.Resize(size=(224, 224), antialias=True)

my_dp = CustomDatapoint(torch.rand(3, 10, 10))
out = t(my_dp)
assert out is my_dp
assert kernel_was_called

# Sanity check to make sure we didn't override the kernel of other types
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

def test_bad_disaptcher_name(self):
class CustomDatapoint(datapoints.Datapoint):
pass

with pytest.raises(ValueError, match="Could not find dispatcher with name"):
F.register_kernel("bad_name", CustomDatapoint)
11 changes: 11 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,18 @@ def decorator(kernel):
return decorator


def _name_to_dispatcher(name):
import torchvision.transforms.v2.functional # noqa

try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(f"Could not find dispatcher with name '{name}'.") from None


def register_kernel(dispatcher, datapoint_cls):
if isinstance(dispatcher, str):
dispatcher = _name_to_dispatcher(name=dispatcher)
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)


Expand Down