Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,8 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):


def cache(fn):
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite buffer size,
but also caches exceptions.

.. warning::

Only use this on deterministic functions.
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
but this also caches exceptions.
"""
sentinel = object()
out_cache = {}
Expand Down Expand Up @@ -238,11 +234,3 @@ def wrapper(*args, **kwargs):
return out

return wrapper


@cache
def script(fn):
try:
return torch.jit.script(fn)
except Exception as error:
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
24 changes: 0 additions & 24 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,16 @@
import dataclasses
import functools
from typing import Callable, Dict, Type

import pytest
import torch
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import ArgsKwargs
from prototype_transforms_kernel_infos import KERNEL_INFOS
from test_prototype_transforms_functional import FUNCTIONAL_INFOS
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]

KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}


# Helper class to use the infos from the old framework for now tests
class PreloadedArgsKwargs(ArgsKwargs):
def load(self, device="cpu"):
args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in self.args)
kwargs = {
keyword: arg.to(device) if isinstance(arg, torch.Tensor) else arg for keyword, arg in self.kwargs.items()
}
return args, kwargs


def preloaded_sample_inputs(args_kwargs):
for args, kwargs in args_kwargs:
yield PreloadedArgsKwargs(*args, **kwargs)


KERNEL_SAMPLE_INPUTS_FN_MAP.update(
{info.functional: functools.partial(preloaded_sample_inputs, info.sample_inputs()) for info in FUNCTIONAL_INFOS}
)


@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
Expand Down
31 changes: 0 additions & 31 deletions test/test_prototype_transforms_dispatchers.py

This file was deleted.

Loading