diff --git a/test/test_transforms.py b/test/test_transforms.py index acee9629f65..da6b0c34645 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -258,6 +258,91 @@ def test_lambda(self): # Checking if Lambda can be printed as string trans.__repr__() + def test_random_apply(self): + random_state = random.getstate() + random.seed(42) + random_apply_transform = transforms.RandomApply( + [ + transforms.RandomRotation((-45, 45)), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ], p=0.75 + ) + img = transforms.ToPILImage()(torch.rand(3, 10, 10)) + num_samples = 250 + num_applies = 0 + for _ in range(num_samples): + out = random_apply_transform(img) + if out != img: + num_applies += 1 + + p_value = stats.binom_test(num_applies, num_samples, p=0.75) + random.setstate(random_state) + assert p_value > 0.0001 + + # Checking if RandomApply can be printed as string + random_apply_transform.__repr__() + + def test_random_choice(self): + random_state = random.getstate() + random.seed(42) + random_choice_transform = transforms.RandomChoice( + [ + transforms.Resize(15), + transforms.Resize(20), + transforms.CenterCrop(10) + ] + ) + img = transforms.ToPILImage()(torch.rand(3, 25, 25)) + num_samples = 250 + num_resize_15 = 0 + num_resize_20 = 0 + num_crop_10 = 0 + for _ in range(num_samples): + out = random_choice_transform(img) + if out.size == (15, 15): + num_resize_15 += 1 + elif out.size == (20, 20): + num_resize_20 += 1 + elif out.size == (10, 10): + num_crop_10 += 1 + + p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333) + assert p_value > 0.0001 + p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333) + assert p_value > 0.0001 + p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333) + assert p_value > 0.0001 + + random.setstate(random_state) + # Checking if RandomChoice can be printed as string + random_choice_transform.__repr__() + + def test_random_order(self): + random_state = random.getstate() + random.seed(42) + random_order_transform = transforms.RandomOrder( + [ + transforms.Resize(20), + transforms.CenterCrop(10) + ] + ) + img = transforms.ToPILImage()(torch.rand(3, 25, 25)) + num_samples = 250 + num_normal_order = 0 + resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img)) + for _ in range(num_samples): + out = random_order_transform(img) + if out == resize_crop_out: + num_normal_order += 1 + + p_value = stats.binom_test(num_normal_order, num_samples, p=0.5) + random.setstate(random_state) + assert p_value > 0.0001 + + # Checking if RandomOrder can be printed as string + random_order_transform.__repr__() + def test_to_tensor(self): test_channels = [1, 3, 4] height, width = 4, 4 diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index aceb296adf3..caf4261335e 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -16,9 +16,9 @@ from . import functional as F __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", - "Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", - "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", - "Grayscale", "RandomGrayscale"] + "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", + "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", + "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"] class Compose(object): @@ -261,6 +261,77 @@ def __repr__(self): return self.__class__.__name__ + '()' +class RandomTransforms(object): + """Base class for a list of transformations with randomness + + Args: + transforms (list or tuple): list of transformations + """ + + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomApply(RandomTransforms): + """Apply randomly a list of transformations with a given probability + + Args: + transforms (list or tuple): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super(RandomApply, self).__init__(transforms) + self.p = p + + def __call__(self, img): + if self.p < random.random(): + return img + for t in self.transforms: + img = t(img) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += '\n p={}'.format(self.p) + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomOrder(RandomTransforms): + """Apply a list of transformations in a random order + """ + def __call__(self, img): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img = self.transforms[i](img) + return img + + +class RandomChoice(RandomTransforms): + """Apply single transformation randomly picked from a list + """ + def __call__(self, img): + t = random.choice(self.transforms) + return t(img) + + class RandomCrop(object): """Crop the given PIL Image at a random location.