54 changes: 32 additions & 22 deletions test/prototype_transforms_dispatcher_infos.py
@@ -1,9 +1,11 @@
import dataclasses
from typing import Callable, Dict, Sequence, Type

from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -15,11 +17,24 @@
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)

test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self._skips_map = {skip.test_name: skip for skip in self.skips}
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
Contributor Author:
This is the same on KernelInfo and here on the DispatcherInfo. I'm going to factor this out into a common base class in a follow-up PR since this PR is already quite large.
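For reference, a minimal sketch of how that shared logic could be factored out — the mixin name TestMarksMixin and the _build_test_marks_map helper are hypothetical, not part of this PR:

# Hypothetical follow-up sketch (not part of this PR): factor the conditional-mark
# bookkeeping shared by KernelInfo and DispatcherInfo into a small mixin.
from collections import defaultdict

class TestMarksMixin:
    def _build_test_marks_map(self):
        # Group marks by the test they apply to, so get_marks is a single dict lookup.
        test_marks_map = defaultdict(list)
        for test_mark in self.test_marks:
            test_marks_map[test_mark.test_id].append(test_mark)
        self._test_marks_map = dict(test_marks_map)

    def get_marks(self, test_id, args_kwargs):
        # Only return marks whose condition holds for this particular sample.
        return [
            test_mark.mark
            for test_mark in self._test_marks_map.get(test_id, [])
            if test_mark.condition(args_kwargs)
        ]

Both info classes would then keep their test_marks field and call self._build_test_marks_map() from their own __post_init__.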


def get_marks(self, test_id, args_kwargs):
return [
conditional_mark.mark
for conditional_mark in self._test_marks_map.get(test_id, [])
if conditional_mark.condition(args_kwargs)
]

def sample_inputs(self, *types):
for type in types or self.kernels.keys():
@@ -28,11 +43,6 @@ def sample_inputs(self, *types):

yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()

def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)


DISPATCHER_INFOS = [
DispatcherInfo(
@@ -206,25 +216,25 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
kernels={
features.Image: F.five_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
test_marks=[
TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason="Integer size is not supported when scripting five_crop."),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int),
)
],
),
DispatcherInfo(
F.ten_crop,
kernels={
features.Image: F.ten_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
test_marks=[
TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason="Integer size is not supported when scripting ten_crop."),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int),
)
],
),
DispatcherInfo(
80 changes: 52 additions & 28 deletions test/prototype_transforms_kernel_infos.py
@@ -2,13 +2,16 @@
import functools
import itertools
import math
from typing import Any, Callable, Dict, Iterable, Optional, Sequence
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F

from _pytest.mark.structures import MarkDecorator
Contributor Author:
This is only needed for annotating the attributes of the data classes below. I'll send a follow-up PR removing them, since the upside of @dataclasses.dataclass is outweighed by the downside of importing from a private namespace.
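For illustration, a minimal sketch of what dropping that annotation could look like — a plain class instead of a dataclass, so nothing from the private _pytest namespace is needed; this is an assumption about the follow-up, not part of this PR:

# Hypothetical follow-up sketch (not part of this PR): same container, no annotation
# on `mark`, so the private MarkDecorator import is no longer required.
class TestMark:
    def __init__(self, test_id, mark, condition=lambda args_kwargs: True):
        # test_id is a (test class name or None, test function name) pair.
        self.test_id = test_id
        # mark is any pytest mark, e.g. pytest.mark.skip(...) or pytest.mark.xfail(...).
        self.mark = mark
        # condition decides per sample whether the mark applies.
        self.condition = condition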

from datasets_utils import combinations_grid
from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders
from torchvision.prototype import features
@@ -17,11 +20,24 @@
__all__ = ["KernelInfo", "KERNEL_INFOS"]


TestID = Tuple[Optional[str], str]
Contributor Author:
Same as above.



@dataclasses.dataclass
class Skip:
test_name: str
reason: str
condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True
class TestMark:
test_id: TestID
mark: MarkDecorator
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True


def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that exist only because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens.
# Thus, we would be wasting CI resources for no reason most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))
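As a purely illustrative note on the re-check mentioned in the comment above (not part of this PR), flipping the hard skip to an xfail is a one-line change; a test that then passes is reported as XPASS and the mark can be retired:

# Illustrative variant only: same entry point, but with xfail, so the test actually
# runs and a pass signals that the framework limitation no longer exists.
def mark_framework_limitation_recheck(test_id, reason):
    return TestMark(test_id, pytest.mark.xfail(reason=reason))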


@dataclasses.dataclass
@@ -43,18 +59,24 @@ class KernelInfo:
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
self._skips_map = {skip.test_name: skip for skip in self.skips}

def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

def get_marks(self, test_id, args_kwargs):
return [
conditional_mark.mark
for conditional_mark in self._test_marks_map.get(test_id, [])
if conditional_mark.condition(args_kwargs)
]


DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
@@ -1449,15 +1471,16 @@ def reference_inputs_ten_crop_image_tensor():
sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
skips=[
Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
test_marks=[
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.xfail(reason="Integer size is not supported when scripting five_crop_image_tensor."),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int),
),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
mark_framework_limitation(
("TestKernels", "test_dtype_and_device_consistency"), "Output is not a tensor."
),
Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
@@ -1466,15 +1489,16 @@ def reference_inputs_ten_crop_image_tensor():
sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
skips=[
Skip(
"test_scripted_vs_eager",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
test_marks=[
TestMark(
("TestKernels", "test_scripted_vs_eager"),
pytest.mark.xfail(reason="Integer size is not supported when scripting ten_crop_image_tensor."),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int),
),
mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
mark_framework_limitation(
("TestKernels", "test_dtype_and_device_consistency"), "Output is not a tensor."
),
Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
101 changes: 61 additions & 40 deletions test/test_prototype_transforms_functional.py
@@ -1,3 +1,4 @@
import functools
import math
import os

@@ -26,33 +27,60 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


@pytest.fixture(autouse=True)
def maybe_skip(request):
# In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist
try:
callspec = request.node.callspec
except AttributeError:
return
def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)):
if condition is None:

try:
info = callspec.params["info"]
args_kwargs = callspec.params["args_kwargs"]
except KeyError:
return
def condition(info):
return True

info.maybe_skip(
test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu")
)
def decorator(test_fn):
parts = test_fn.__qualname__.split(".")
if len(parts) == 1:
test_class_name = None
test_function_name = parts[0]
elif len(parts) == 2:
test_class_name, test_function_name = parts
else:
raise pytest.UsageError()
test_id = (test_class_name, test_function_name)

argnames = ("info", "args_kwargs")
argvalues = []
for info in infos:
if not condition(info):
continue

args_kwargs = list(args_kwargs_fn(info))
name = name_fn(info)
idx_field_len = len(str(len(args_kwargs)))

for idx, args_kwargs_ in enumerate(args_kwargs):
argvalues.append(
pytest.param(
info,
args_kwargs_,
marks=info.get_marks(test_id, args_kwargs_),
id=f"{name}-{idx:0{idx_field_len}}",
)
)

return pytest.mark.parametrize(argnames, argvalues)(test_fn)

return decorator


class TestKernels:
sample_inputs = pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
for info in KERNEL_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs_fn())
],
make_kernel_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name
)
sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
)
reference_inputs = make_kernel_args_kwargs_parametrization(
KERNEL_INFOS,
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
condition=lambda info: info.reference_fn is not None,
)

@sample_inputs
@@ -125,7 +153,7 @@ def test_no_inplace(self, info, args_kwargs, device):
input_version = input._version
output = info.kernel(input, *other_args, **kwargs)

assert output is not input or output._version == input_version
assert output is not input or input._version == input_version

@sample_inputs
@needs_cuda
@@ -148,15 +176,7 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
assert output.dtype == input.dtype
assert output.device == input.device

@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}")
for info in KERNEL_INFOS
for idx, args_kwargs in enumerate(info.reference_inputs_fn())
if info.reference_fn is not None
],
)
@reference_inputs
def test_against_reference(self, info, args_kwargs):
args, kwargs = args_kwargs.load("cpu")

@@ -167,15 +187,16 @@ def test_against_reference(self, info, args_kwargs):


class TestDispatchers:
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels
],
make_dispatcher_args_kwargs_parametrization = functools.partial(
make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__
)
image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: features.Image in info.kernels,
)

@image_sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_smoke(self, info, args_kwargs, device):
dispatcher = script(info.dispatcher)