Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def normalize(


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(-x.pow_(2), dim=0)
kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
return kernel1d


Expand All @@ -89,54 +89,61 @@ def gaussian_blur_image_tensor(
# TODO: consider deprecating integers from sigma on the future
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if len(kernel_size) != 2:
elif len(kernel_size) != 2:
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0:
raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
if isinstance(sigma, (int, float)):
sigma = [float(sigma), float(sigma)]
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
sigma = [sigma[0], sigma[0]]
if len(sigma) != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
else:
if isinstance(sigma, (list, tuple)):
length = len(sigma)
if length == 1:
s = float(sigma[0])
sigma = [s, s]
elif length != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
elif isinstance(sigma, (int, float)):
s = float(sigma)
sigma = [s, s]
else:
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
for s in sigma:
if s <= 0.0:
raise ValueError(f"sigma should have positive values. Got {sigma}")

if image.numel() == 0:
return image

dtype = image.dtype
shape = image.shape

if image.ndim > 4:
ndim = image.ndim
if ndim == 3:
image = image.unsqueeze(dim=0)
elif ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

dtype = image.dtype if torch.is_floating_point(image) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
fp = torch.is_floating_point(image)
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
output = image if fp else image.to(dtype=torch.float32)

# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
output = torch_pad(image, padding, mode="reflect")
output = conv2d(output, kernel, groups=output.shape[-3])

output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
output = torch_pad(output, padding, mode="reflect")
output = conv2d(output, kernel, groups=shape[-3])

if needs_unsquash:
if ndim == 3:
output = output.squeeze(dim=0)
elif ndim > 4:
output = output.reshape(shape)

if not fp:
output = output.round_().to(dtype=dtype)

return output


Expand Down