Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3e81aef
fix prototype kernels
pmeier Oct 24, 2022
33852be
fix stable kernels
pmeier Oct 24, 2022
3a92412
fix tests
pmeier Oct 25, 2022
e13613a
make test more robust
pmeier Oct 25, 2022
a400225
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 25, 2022
e053125
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 25, 2022
3327e04
improve invert for signed integers
pmeier Oct 27, 2022
91e8c66
Merge branch 'main' into fix-hardcoded-255
datumbox Oct 27, 2022
bdd8127
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 28, 2022
c672425
improve invert
pmeier Oct 28, 2022
6375627
fix posterize
pmeier Oct 28, 2022
6895f71
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 28, 2022
c0236fc
Revert "assume that integer images are [0, 255] in equalize (#6859)"
pmeier Oct 28, 2022
8713528
Merge branch 'fix-hardcoded-255' of https://github.com/pmeier/vision …
pmeier Oct 28, 2022
9acf2f4
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 2, 2022
402b01f
fix solarize in AA
pmeier Nov 2, 2022
d0394b7
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
5f33f4a
fix resize
pmeier Nov 3, 2022
3a13a08
Revert "fix resize"
pmeier Nov 3, 2022
7765a47
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
2d0549d
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
f594ceb
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
48603b0
add comment to float max value
pmeier Nov 3, 2022
a61d44f
Merge branch 'fix-hardcoded-255' of https://github.com/pmeier/vision …
pmeier Nov 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)

Expand All @@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
_FT._assert_channels(image, [1, 3])

fp = image.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image.dtype)
output = image.mul(brightness_factor).clamp_(0, bound)
return output if fp else output.to(image.dtype)

Expand Down Expand Up @@ -226,19 +226,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return image

orig_dtype = image.dtype
if image.dtype == torch.uint8:
image = image / 255.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing the conversion manually, I've opted to use our dtype-conversion kernel for this. Note that the old manual division also implicitly converted the image to float32, since the divisor is a float.

image = convert_dtype_image_tensor(image, torch.float32)

image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
h.add_(hue_factor).remainder_(1.0)
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)

if orig_dtype == torch.uint8:
image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

return image_hue_adj
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)


adjust_hue_image_pil = _FP.adjust_hue
Expand Down Expand Up @@ -313,8 +309,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:


def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
bound = 1 if image.is_floating_point() else 255
if threshold > bound:
if threshold > _FT._max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

return torch.where(image >= threshold, invert_image_tensor(image), image)
Expand Down Expand Up @@ -350,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
# exit earlier on empty images
return image

bound = 1.0 if image.is_floating_point() else 255.0
bound = _FT._max_value(image.dtype)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32

minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
Expand Down Expand Up @@ -467,7 +462,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.dtype == torch.uint8:
return image.bitwise_not()
else:
return (1 if image.is_floating_point() else 255) - image # type: ignore[no-any-return]
return _FT._max_value(image.dtype) - image # type: ignore[no-any-return]


invert_image_pil = _FP.invert
Expand Down
24 changes: 7 additions & 17 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")


def _assert_threshold(img: Tensor, threshold: float) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper was only used once, in `solarize`, so I inlined it there.

bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")


def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
Expand Down Expand Up @@ -212,19 +206,15 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
return img

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = convert_image_dtype(img, torch.float32)

img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)

if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

return img_hue_adj
return convert_image_dtype(img_hue_adj, orig_dtype)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
Expand Down Expand Up @@ -263,7 +253,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)


Expand Down Expand Up @@ -775,8 +765,7 @@ def invert(img: Tensor) -> Tensor:

_assert_channels(img, [1, 3])

bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
return _max_value(img.dtype) - img


def posterize(img: Tensor, bits: int) -> Tensor:
Expand All @@ -802,7 +791,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

_assert_channels(img, [1, 3])

_assert_threshold(img, threshold)
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")

inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
Expand Down Expand Up @@ -849,7 +839,7 @@ def autocontrast(img: Tensor) -> Tensor:

_assert_channels(img, [1, 3])

bound = 1.0 if img.is_floating_point() else 255.0
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32

minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
Expand Down