Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,15 @@ def test_call_consistency(config, args_kwargs):
def test_get_params_alias(config):
assert config.prototype_cls.get_params is config.legacy_cls.get_params

if not config.args_kwargs:
return

args, kwargs = config.args_kwargs[0]
legacy_transform = config.legacy_cls(*args, **kwargs)
prototype_transform = config.prototype_cls(*args, **kwargs)

assert prototype_transform.get_params is legacy_transform.get_params


@pytest.mark.parametrize(
("transform_cls", "args_kwargs"),
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init_subclass__(cls) -> None:
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined]
cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined]

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
Expand Down