diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ce97ce0575d..98bd7a52712 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -727,39 +727,38 @@ def _pad_with_scalar_fill( shape = image.shape num_channels, height, width = shape[-3:] - if image.numel() > 0: - image = image.reshape(-1, num_channels, height, width) - - if padding_mode == "edge": - # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map - # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad` - # name. - padding_mode = "replicate" - - if padding_mode == "constant": - image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill)) - elif padding_mode in ("reflect", "replicate"): - # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs. - # TODO: See https://github.com/pytorch/pytorch/issues/40763 - dtype = image.dtype - if not image.is_floating_point(): - needs_cast = True - image = image.to(torch.float32) - else: - needs_cast = False - - image = torch_pad(image, torch_padding, mode=padding_mode) - - if needs_cast: - image = image.to(dtype) - else: # padding_mode == "symmetric" - image = _FT._pad_symmetric(image, torch_padding) + batch_size = 1 + for s in shape[:-3]: + batch_size *= s + + image = image.reshape(batch_size, num_channels, height, width) + + if padding_mode == "edge": + # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map + # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad` + # name. + padding_mode = "replicate" + + if padding_mode == "constant": + image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill)) + elif padding_mode in ("reflect", "replicate"): + # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs. + # TODO: See https://github.com/pytorch/pytorch/issues/40763 + dtype = image.dtype + if not image.is_floating_point(): + needs_cast = True + image = image.to(torch.float32) + else: + needs_cast = False - new_height, new_width = image.shape[-2:] - else: - left, right, top, bottom = torch_padding - new_height = height + top + bottom - new_width = width + left + right + image = torch_pad(image, torch_padding, mode=padding_mode) + + if needs_cast: + image = image.to(dtype) + else: # padding_mode == "symmetric" + image = _FT._pad_symmetric(image, torch_padding) + + new_height, new_width = image.shape[-2:] return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) @@ -868,7 +867,24 @@ def pad( return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) -crop_image_tensor = _FT.crop +def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + h, w = image.shape[-2:] + + right = left + width + bottom = top + height + + if left < 0 or top < 0 or right > w or bottom > h: + image = image[..., max(top, 0) : bottom, max(left, 0) : right] + torch_padding = [ + max(min(right, 0) - left, 0), + max(right - max(w, left), 0), + max(min(bottom, 0) - top, 0), + max(bottom - max(h, top), 0), + ] + return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + return image[..., top:bottom, left:right] + + crop_image_pil = _FP.crop @@ -893,7 +909,18 @@ def crop_bounding_box( def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: - return crop_image_tensor(mask, top, left, height, width) + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = crop_image_tensor(mask, top, left, height, width) + + if needs_squeeze: + output = output.squeeze(0) + + return output def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: