Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 113 additions & 9 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
import collections.abc
import dataclasses

from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Type
from typing import Callable, Dict, List, Optional, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
from prototype_common_utils import BoundingBoxLoader
from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]

KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}


@dataclasses.dataclass
class PILKernelInfo:
kernel: Callable
kernel_name: str = dataclasses.field(default=None)

def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__


def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
Expand All @@ -28,21 +40,35 @@ def skip_integer_size_jit(name="size"):
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)

def sample_inputs(self, *types):
for type in types or self.kernels.keys():
if type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
if feature_type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")

sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
if not filter_metadata:
yield from sample_inputs
else:
for args_kwargs in sample_inputs:
for attribute in feature_type.__annotations__.keys():
if attribute in args_kwargs.kwargs:
del args_kwargs.kwargs[attribute]
Comment on lines +67 to +69
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes metadata like image_size or format from the sample inputs. They are only there in the first place since we are using the sample inputs of the kernels for the dispatcher.


yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
yield args_kwargs

def maybe_skip(self, *, test_name, args_kwargs, device):
skips = self._skips_map.get(test_name)
Expand All @@ -54,6 +80,31 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
pytest.skip(skip.reason)


def fill_sequence_needs_broadcast(args_kwargs, device):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
except KeyError:
return False

if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
return False

return image_loader.num_channels > 1


skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
"test_dispatch_pil",
condition=fill_sequence_needs_broadcast,
reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
)

skip_dispatch_feature = Skip(
"test_dispatch_feature",
reason="Dispatcher doesn't support arbitrary feature dispatch.",
)


DISPATCHER_INFOS = [
DispatcherInfo(
F.horizontal_flip,
Expand All @@ -62,6 +113,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.horizontal_flip_bounding_box,
features.Mask: F.horizontal_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
),
DispatcherInfo(
F.resize,
Expand All @@ -70,6 +122,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.resize_bounding_box,
features.Mask: F.resize_mask,
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
skips=[
skip_integer_size_jit(),
],
Expand All @@ -81,7 +134,11 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.affine_bounding_box,
features.Mask: F.affine_mask,
},
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
],
),
DispatcherInfo(
F.vertical_flip,
Expand All @@ -90,6 +147,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.vertical_flip_bounding_box,
features.Mask: F.vertical_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
),
DispatcherInfo(
F.rotate,
Expand All @@ -98,6 +156,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.rotate_bounding_box,
features.Mask: F.rotate_mask,
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
),
DispatcherInfo(
F.crop,
Expand All @@ -106,6 +165,17 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.crop_bounding_box,
features.Mask: F.crop_mask,
},
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
skips=[
Skip(
"test_dispatch_feature",
condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
reason=(
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
"since that is sufficient for the kernel."
),
)
],
),
DispatcherInfo(
F.resized_crop,
Expand All @@ -114,6 +184,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.resized_crop_bounding_box,
features.Mask: F.resized_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
),
DispatcherInfo(
F.pad,
Expand All @@ -122,6 +193,10 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.pad_bounding_box,
features.Mask: F.pad_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
),
DispatcherInfo(
F.perspective,
Expand All @@ -130,6 +205,10 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
),
DispatcherInfo(
F.elastic,
Expand All @@ -138,6 +217,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.elastic_bounding_box,
features.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
),
DispatcherInfo(
F.center_crop,
Expand All @@ -146,6 +226,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.center_crop_bounding_box,
features.Mask: F.center_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
skips=[
skip_integer_size_jit("output_size"),
],
Expand All @@ -155,6 +236,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
kernels={
features.Image: F.gaussian_blur_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
Expand All @@ -165,95 +247,117 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
kernels={
features.Image: F.equalize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
),
DispatcherInfo(
F.invert,
kernels={
features.Image: F.invert_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
),
DispatcherInfo(
F.posterize,
kernels={
features.Image: F.posterize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
),
DispatcherInfo(
F.solarize,
kernels={
features.Image: F.solarize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
),
DispatcherInfo(
F.autocontrast,
kernels={
features.Image: F.autocontrast_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
),
DispatcherInfo(
F.adjust_sharpness,
kernels={
features.Image: F.adjust_sharpness_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
),
DispatcherInfo(
F.erase,
kernels={
features.Image: F.erase_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
skips=[
skip_dispatch_feature,
],
),
DispatcherInfo(
F.adjust_brightness,
kernels={
features.Image: F.adjust_brightness_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
),
DispatcherInfo(
F.adjust_contrast,
kernels={
features.Image: F.adjust_contrast_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
),
DispatcherInfo(
F.adjust_gamma,
kernels={
features.Image: F.adjust_gamma_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
),
DispatcherInfo(
F.adjust_hue,
kernels={
features.Image: F.adjust_hue_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
),
DispatcherInfo(
F.adjust_saturation,
kernels={
features.Image: F.adjust_saturation_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
),
DispatcherInfo(
F.five_crop,
kernels={
features.Image: F.five_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
skips=[
skip_integer_size_jit(),
skip_dispatch_feature,
],
),
DispatcherInfo(
F.ten_crop,
kernels={
features.Image: F.ten_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
skips=[
skip_integer_size_jit(),
skip_dispatch_feature,
],
),
DispatcherInfo(
F.normalize,
kernels={
features.Image: F.normalize_image_tensor,
},
skips=[
skip_dispatch_feature,
],
),
]
2 changes: 1 addition & 1 deletion test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class KernelInfo:
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name: Optional[str] = None
kernel_name: str = dataclasses.field(default=None)
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
Expand Down
Loading