Commit 7902956

simplify equalize correctness
1 parent 58d6ecb commit 7902956

1 file changed: +9, -55 lines


test/test_transforms_v2_refactored.py

Lines changed: 9 additions & 55 deletions
@@ -3891,62 +3891,16 @@ def test_functional_signature(self, kernel, input_type):
     def test_transform(self, make_input):
         check_transform(transforms.RandomEqualize(p=1), make_input())

-    # We are not using the default `make_image` here since that uniformly samples the values over the whole value range.
-    # Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform one,
-    # the information gain is low if we already provide something really close to the expected value.
-    def _make_correctness_image(self, *, type, **kwargs):
-        shape = (3, 117, 253)
-        dtype = torch.uint8
-        device = "cpu"
-
-        max_value = get_max_value(dtype)
-
-        def make_constant_image(*, value_factor=0.0):
-            return torch.full(shape, value_factor * max_value, dtype=dtype, device=device)
-
-        def make_uniform_band_distributed_image(*, low_factor=0.1, high_factor=0.9):
-            return torch.testing.make_tensor(
-                shape, dtype=dtype, device=device, low=low_factor * max_value, high=high_factor * max_value
-            )
-
-        def make_beta_distributed_image(*, alpha=2.0, beta=5.0):
-            image = torch.distributions.Beta(alpha, beta).sample(shape)
-            image.mul_(get_max_value(dtype)).round_()
-            return image.to(dtype=dtype, device=device)
-
-        make_fn = {
-            "constant": make_constant_image,
-            "uniform_band_distributed": make_uniform_band_distributed_image,
-            "beta_distributed": make_beta_distributed_image,
-        }[type]
-        return tv_tensors.Image(make_fn(**kwargs))
-
-    @pytest.mark.parametrize(
-        "make_correctness_image_kwargs",
-        [
-            *[dict(type="constant", value_factor=value_factor) for value_factor in [0.0, 0.5, 1.0]],
-            *[
-                dict(type="uniform_band_distributed", low_factor=low_factor, high_factor=high_factor)
-                for low_factor, high_factor in [
-                    (0.0, 0.25),
-                    (0.25, 0.75),
-                    (0.75, 1.0),
-                ]
-            ],
-            *[
-                dict(type="beta_distributed", alpha=alpha, beta=beta)
-                for alpha, beta in [
-                    (0.5, 0.5),
-                    (2.0, 2.0),
-                    (2.0, 5.0),
-                    (5.0, 2.0),
-                ]
-            ],
-        ],
-    )
+    @pytest.mark.parametrize(("low", "high"), [(0, 64), (64, 192), (192, 256), (0, 1), (127, 128), (255, 256)])
     @pytest.mark.parametrize("fn", [F.equalize, transform_cls_to_functional(transforms.RandomEqualize, p=1)])
-    def test_image_correctness(self, make_correctness_image_kwargs, fn):
-        image = self._make_correctness_image(**make_correctness_image_kwargs)
+    def test_image_correctness(self, low, high, fn):
+        # We are not using the default `make_image` here since that uniformly samples the values over the whole value
+        # range. Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform
+        # one over the full range, the information gain is low if we already provide something really close to the
+        # expected value.
+        image = tv_tensors.Image(
+            torch.testing.make_tensor((3, 117, 253), dtype=torch.uint8, device="cpu", low=low, high=high)
+        )

         actual = fn(image)
         expected = F.to_image(F.equalize(F.to_pil_image(image)))
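
The new (low, high) bands lean on the sampling contract of torch.testing.make_tensor, which draws integers from the half-open interval [low, high): (0, 64), (64, 192), and (192, 256) produce narrow uniform bands, while the degenerate bands (0, 1), (127, 128), and (255, 256) collapse to constant images at 0, 127, and 255, covering what the deleted constant-image helper used to generate. Below is a minimal standalone sketch of the same correctness pattern, assuming a torchvision build that ships the v2 API (tv_tensors and transforms.v2.functional); it mirrors the test above but is not part of the suite:

# Minimal standalone sketch of the correctness check above. Assumption:
# torchvision with the v2 API; this mirrors the test, it is not the test.
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# A narrow uint8 band; for integer dtypes make_tensor samples from [low, high).
image = tv_tensors.Image(
    torch.testing.make_tensor((3, 117, 253), dtype=torch.uint8, device="cpu", low=64, high=192)
)

actual = F.equalize(image)
# PIL serves as the reference implementation, exactly as in the test.
expected = F.to_image(F.equalize(F.to_pil_image(image)))
torch.testing.assert_close(actual, expected, rtol=0, atol=0)

# Equalization should stretch the narrow band toward the full uint8 range.
print(int(image.min()), int(image.max()))    # roughly 64 .. 191
print(int(actual.min()), int(actual.max()))  # roughly 0 .. 255

The single-value bands exercise the degenerate case where there is nothing to redistribute, so equalize should hand those images back unchanged.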
