11import torch
2+ from torch .nn .functional import conv2d
23from torchvision .prototype import features
34from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
45
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
111112 if image .numel () == 0 or height <= 2 or width <= 2 :
112113 return image
113114
115+ bound = _FT ._max_value (image .dtype )
116+ fp = image .is_floating_point ()
114117 shape = image .shape
115118
116119 if image .ndim > 4 :
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
119122 else :
120123 needs_unsquash = False
121124
122- output = _blend (image , _FT ._blurred_degenerate_image (image ), sharpness_factor )
125+ # The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
126+ kernel_dtype = image .dtype if fp else torch .float32
127+ a , b = 1.0 / 13.0 , 5.0 / 13.0
128+ kernel = torch .tensor ([[a , a , a ], [a , b , a ], [a , a , a ]], dtype = kernel_dtype , device = image .device )
129+ kernel = kernel .expand (num_channels , 1 , 3 , 3 )
130+
131+ # We copy and cast at the same time to avoid modifications on the original data
132+ output = image .to (dtype = kernel_dtype , copy = True )
133+ blurred_degenerate = conv2d (output , kernel , groups = num_channels )
134+ if not fp :
135+ # it is better to round before cast
136+ blurred_degenerate = blurred_degenerate .round_ ()
137+
138+ # Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
139+ view = output [..., 1 :- 1 , 1 :- 1 ]
140+
141+ # We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent:
142+ # x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
143+ view .add_ (blurred_degenerate .sub_ (view ), alpha = (1.0 - sharpness_factor ))
144+
145+ # The actual data of ouput have been modified by the above. We only need to clamp and cast now.
146+ output = output .clamp_ (0 , bound )
147+ if not fp :
148+ output = output .to (image .dtype )
123149
124150 if needs_unsquash :
125151 output = output .reshape (shape )
0 commit comments