diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 6d649ab18fa..2d71e7fef17 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -11,7 +11,7 @@ import numbers import types import collections - +import functools class Compose(object): """Composes several transforms together. @@ -236,13 +236,52 @@ class Pad(object): padding (int or sequence): Padding on each border. If a sequence of length 4, it is used to pad left, top, right and bottom borders respectively. fill: Pixel fill value. Default is 0. + type:padding type: constant,reflect,edge,symmetric """ - def __init__(self, padding, fill=0): + def __init__(self, padding, fill=0,type="constant"): assert isinstance(padding, numbers.Number) assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) self.padding = padding self.fill = fill + assert (type in ["constant","edge","symmetric","reflect"]) + self.type = type + + def __expand_reflect(slef,image, border=0): + """ + Add border to the image(Symmetric padding) + + :param image: The image to expand. + :param border: Border width, in pixels. + :return: An image. + """ + img = np.asarray(image) + img = np.pad(img, pad_width=border, mode="reflect") + return Image.fromarray(np.uint8(img)) + + def __expand_edge(self,image, border=0): + """ + Add border to the image(Symmetric padding) + + :param image: The image to expand. + :param border: Border width, in pixels. + :return: An image. + """ + img = np.asarray(image) + img = np.pad(img, pad_width=border, mode="edge") + return Image.fromarray(np.uint8(img)) + + def __expand_symmetric(self,image, border=0): + """ + Add border to the image(Symmetric padding) + + :param image: The image to expand. + :param border: Border width, in pixels. + :return: An image. + """ + img = np.asarray(image) + img = np.pad(img, pad_width=border, mode="symmetric") + return Image.fromarray(np.uint8(img)) def __call__(self, img): """ @@ -252,7 +291,14 @@ def __call__(self, img): Returns: PIL.Image: Padded image. """ - return ImageOps.expand(img, border=self.padding, fill=self.fill) + if self.type == "constant": + return ImageOps.expand(img, border=self.padding, fill=self.fill) + elif self.type == "symmetric": + return self.__expand_symmetric(img, border=self.padding) + elif self.type == "reflect": + return self.__expand_reflect(img, border=self.padding) + elif self.type == "edge": + return self.__expand_edge(img, border=self.padding) class Lambda(object): @@ -369,3 +415,7 @@ def __call__(self, img): scale = Scale(self.size, interpolation=self.interpolation) crop = CenterCrop(self.size) return crop(scale(img)) + + + +