4 changes: 4 additions & 0 deletions test/test_prototype_transforms_consistency.py
@@ -323,6 +323,9 @@ def __init__(
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]),
# We updated the gaussian blur kernel generation with a faster and numerically more stable version.
# This makes float32 accumulation visible in elastic transform -> we need to relax the consistency tolerance
closeness_kwargs={"rtol": 1e-1, "atol": 1},
),
ConsistencyConfig(
prototype_transforms.GaussianBlur,
@@ -333,6 +336,7 @@ def __init__(
ArgsKwargs(kernel_size=3, sigma=0.7),
ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
],
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
),
ConsistencyConfig(
prototype_transforms.RandomAffine,
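For reference, a minimal sketch of what these tolerances amount to, assuming the consistency harness forwards closeness_kwargs to torch.testing.assert_close (the plumbing is not part of this diff); the tensor values below are made up purely for illustration:

import torch

# assert_close passes element-wise when |actual - expected| <= atol + rtol * |expected|
expected = torch.tensor([100.0, 0.5])
actual = torch.tensor([109.0, 1.4])

# The relaxed elastic-transform tolerances accept this difference ...
torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1)

# ... while the tighter GaussianBlur tolerances reject it.
try:
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
except AssertionError:
    print("not close under rtol=1e-5, atol=1e-5")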
30 changes: 29 additions & 1 deletion torchvision/prototype/transforms/functional/_misc.py
@@ -1,3 +1,4 @@
import math
from typing import List, Optional, Union

import PIL.Image
@@ -32,6 +33,22 @@ def normalize(
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size)
kernel1d = torch.softmax(-x.pow_(2), dim=0)
return kernel1d


def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
return kernel2d


def gaussian_blur_image_tensor(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
@@ -70,7 +87,18 @@ def gaussian_blur_image_tensor(
else:
needs_unsquash = False

output = _FT.gaussian_blur(image, kernel_size, sigma)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])

image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])

# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
output = _FT.torch_pad(image, padding, mode="reflect")
output = _FT.conv2d(output, kernel, groups=output.shape[-3])

output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)

if needs_unsquash:
output = output.reshape(shape)
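As a side note, here is a self-contained sketch of the technique the new helpers implement, written against the public torch API rather than the private _FT wrappers used above (the helper names, tensor shapes, and sigma values below are illustrative assumptions, not part of the PR): taking the softmax of -x**2 at symmetric sample points yields exactly the classic normalized Gaussian kernel, and the blur itself becomes a reflect-padded depthwise (grouped) convolution.

import math
import torch
import torch.nn.functional as F

def softmax_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    # Same construction as _get_gaussian_kernel1d in this diff.
    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
    x = torch.linspace(-lim, lim, steps=kernel_size)
    return torch.softmax(-x.pow(2), dim=0)

def classic_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    # Textbook formulation: exponentiate, then normalize explicitly.
    x = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    k = torch.exp(-x.pow(2) / (2 * sigma ** 2))
    return k / k.sum()

# The two formulations agree: softmax(-(x / (sqrt(2) * sigma))**2) is the
# normalized Gaussian, since softmax divides the exponentials by their sum.
assert torch.allclose(softmax_kernel1d(5, 0.7), classic_kernel1d(5, 0.7), atol=1e-6)

# Applying the blur as a reflect-padded depthwise convolution, mirroring the
# new gaussian_blur_image_tensor body (a batched float CHW image is assumed).
image = torch.rand(1, 3, 32, 32)
kx, ky = 5, 5
sx, sy = 1.4, 1.4
kernel2d = softmax_kernel1d(ky, sy).unsqueeze(-1) * softmax_kernel1d(kx, sx)
kernel = kernel2d.expand(image.shape[-3], 1, ky, kx)  # one kernel per channel

# padding = (left, right, top, bottom), as in the diff above
padded = F.pad(image, [kx // 2, kx // 2, ky // 2, ky // 2], mode="reflect")
blurred = F.conv2d(padded, kernel, groups=image.shape[-3])
assert blurred.shape == image.shape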