-
Notifications
You must be signed in to change notification settings - Fork 7.2k
use pytest markers instead of custom solution for prototype transforms functional tests #6653
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
0f81474
220a31d
6d994a9
4d6fafb
9feb1ab
d6e619e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,16 @@ | |
| import functools | ||
| import itertools | ||
| import math | ||
| from typing import Any, Callable, Dict, Iterable, Optional, Sequence | ||
| from collections import defaultdict | ||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch.testing | ||
| import torchvision.ops | ||
| import torchvision.prototype.transforms.functional as F | ||
|
|
||
| from _pytest.mark.structures import MarkDecorator | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only needed for annotating the attributes of the data classes below. I'll send a follow-up PR removing them, since the upside of |
||
| from datasets_utils import combinations_grid | ||
| from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders | ||
| from torchvision.prototype import features | ||
|
|
@@ -17,11 +20,24 @@ | |
| __all__ = ["KernelInfo", "KERNEL_INFOS"] | ||
|
|
||
|
|
||
| TestID = Tuple[Optional[str], str] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Skip: | ||
| test_name: str | ||
| reason: str | ||
| condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True | ||
| class TestMark: | ||
| test_id: TestID | ||
| mark: MarkDecorator | ||
| condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True | ||
|
|
||
|
|
||
| def mark_framework_limitation(test_id, reason): | ||
| # The purpose of this function is to have a single entry point for skip marks that are only there, because the test | ||
| # framework cannot handle the kernel in general or a specific parameter combination. | ||
| # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is | ||
| # still justified. | ||
| # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus, | ||
| # we are wasting CI resources for no reason for most of the time | ||
| return TestMark(test_id, pytest.mark.skip(reason=reason)) | ||
pmeier marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
|
|
@@ -43,18 +59,24 @@ class KernelInfo: | |
| reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None | ||
| # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. | ||
| closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) | ||
| skips: Sequence[Skip] = dataclasses.field(default_factory=list) | ||
| _skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False) | ||
| test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) | ||
| _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) | ||
|
|
||
| def __post_init__(self): | ||
| self.kernel_name = self.kernel_name or self.kernel.__name__ | ||
| self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn | ||
| self._skips_map = {skip.test_name: skip for skip in self.skips} | ||
|
|
||
| def maybe_skip(self, *, test_name, args_kwargs, device): | ||
| skip = self._skips_map.get(test_name) | ||
| if skip and skip.condition(args_kwargs, device): | ||
| pytest.skip(skip.reason) | ||
| test_marks_map = defaultdict(list) | ||
| for test_mark in self.test_marks: | ||
| test_marks_map[test_mark.test_id].append(test_mark) | ||
| self._test_marks_map = dict(test_marks_map) | ||
|
|
||
| def get_marks(self, test_id, args_kwargs): | ||
| return [ | ||
| conditional_mark.mark | ||
| for conditional_mark in self._test_marks_map.get(test_id, []) | ||
| if conditional_mark.condition(args_kwargs) | ||
| ] | ||
|
|
||
|
|
||
| DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( | ||
|
|
@@ -1449,15 +1471,16 @@ def reference_inputs_ten_crop_image_tensor(): | |
| sample_inputs_fn=sample_inputs_five_crop_image_tensor, | ||
| reference_fn=pil_reference_wrapper(F.five_crop_image_pil), | ||
| reference_inputs_fn=reference_inputs_five_crop_image_tensor, | ||
| skips=[ | ||
| Skip( | ||
| "test_scripted_vs_eager", | ||
| condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), | ||
| reason="Integer size is not supported when scripting five_crop_image_tensor.", | ||
| test_marks=[ | ||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| TestMark( | ||
| ("TestKernels", "test_scripted_vs_eager"), | ||
| pytest.mark.xfail(reason="Integer size is not supported when scripting five_crop_image_tensor."), | ||
| condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int), | ||
| ), | ||
| mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."), | ||
| mark_framework_limitation( | ||
| ("TestKernels", "test_dtype_and_device_consistency"), "Output is not a tensor." | ||
| ), | ||
| Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."), | ||
| Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."), | ||
| Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."), | ||
| ], | ||
| closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, | ||
| ), | ||
|
|
@@ -1466,15 +1489,16 @@ def reference_inputs_ten_crop_image_tensor(): | |
| sample_inputs_fn=sample_inputs_ten_crop_image_tensor, | ||
| reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), | ||
| reference_inputs_fn=reference_inputs_ten_crop_image_tensor, | ||
| skips=[ | ||
| Skip( | ||
| "test_scripted_vs_eager", | ||
| condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), | ||
| reason="Integer size is not supported when scripting ten_crop_image_tensor.", | ||
| test_marks=[ | ||
| TestMark( | ||
| ("TestKernels", "test_scripted_vs_eager"), | ||
| pytest.mark.xfail(reason="Integer size is not supported when scripting ten_crop_image_tensor."), | ||
| condition=lambda args_kwargs: isinstance(args_kwargs.kwargs["size"], int), | ||
| ), | ||
| mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."), | ||
| mark_framework_limitation( | ||
| ("TestKernels", "test_dtype_and_device_consistency"), "Output is not a tensor." | ||
| ), | ||
| Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."), | ||
| Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."), | ||
| Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."), | ||
| ], | ||
| closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, | ||
| ), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the same on
KernelInfoand here on theDispatcherInfo. I'm going to factor this out into a common base class in a follow-up PR since this PR is already quite large.