Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2181,3 +2181,36 @@ def test_unsupported_types(self, dispatcher, make_input):

with pytest.raises(TypeError, match=re.escape(str(type(input)))):
dispatcher(input)


class TestRegisterKernel:
    """Tests for the public ``F.register_kernel`` registration mechanism."""

    @pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
    def test_register_kernel(self, dispatcher):
        """A kernel registered for a custom Datapoint subclass is dispatched to it,
        both when the dispatcher is passed as a callable and as a string name."""

        class CustomDatapoint(datapoints.Datapoint):
            pass

        kernel_was_called = False

        @F.register_kernel(dispatcher, CustomDatapoint)
        def new_resize(dp, *args, **kwargs):
            nonlocal kernel_was_called
            kernel_was_called = True
            return dp

        t = transforms.Resize(size=(224, 224), antialias=True)

        my_dp = CustomDatapoint(torch.rand(3, 10, 10))
        out = t(my_dp)
        # The custom kernel is an identity, so the transform must return the
        # input object itself and the kernel must have been invoked.
        assert out is my_dp
        assert kernel_was_called

        # Sanity check to make sure we didn't override the kernel of other types.
        # (These comparisons were previously missing `assert` and were no-ops.)
        assert t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
        assert t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

    def test_bad_dispatcher_name(self):
        """Registering against an unknown dispatcher name raises a ValueError."""

        class CustomDatapoint(datapoints.Datapoint):
            pass

        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
            F.register_kernel("bad_name", CustomDatapoint)
15 changes: 15 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,22 @@ def decorator(kernel):
return decorator


def _name_to_dispatcher(name):
    """Resolve a dispatcher given its public name, e.g. ``"resize"`` -> ``F.resize``.

    Imported lazily to avoid a circular import: this module is itself imported
    by ``torchvision.transforms.v2.functional``.

    Raises:
        ValueError: If no dispatcher with that name exists in the functional namespace.
    """
    import torchvision.transforms.v2.functional  # noqa

    try:
        # Direct attribute lookup instead of scanning __dict__ for a matching
        # __name__ -- simpler and O(1). This only runs at registration time,
        # never on the dispatch hot path.
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        # `from None` suppresses the uninformative AttributeError context.
        raise ValueError(f"Could not find dispatcher with name '{name}'.") from None


def register_kernel(dispatcher, datapoint_cls):
    """Register a kernel for ``dispatcher`` handling inputs of type ``datapoint_cls``.

    ``dispatcher`` may be the dispatcher callable itself or its public name
    (e.g. ``"resize"``), which is resolved first.
    """
    resolved = _name_to_dispatcher(name=dispatcher) if isinstance(dispatcher, str) else dispatcher
    # datapoint_wrapper=False: user-supplied kernels return their output as-is,
    # without re-wrapping it into the datapoint class.
    return _register_kernel_internal(resolved, datapoint_cls, datapoint_wrapper=False)


Expand Down