@@ -223,19 +223,16 @@ def __init__(
         _check_padding_arg(padding)
         _check_padding_mode_arg(padding_mode)
 
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
         self.padding = padding
         self.fill = _setup_fill_arg(fill)
         self.padding_mode = padding_mode
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-
-        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-        padding = self.padding
-        if not isinstance(padding, int):
-            padding = list(padding)
-
-        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
+        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
 
 
 class RandomZoomOut(_RandomApplyTransform):
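For context, a minimal standalone sketch of the cast this hunk hoists into `__init__` (assuming, as the mypy comment suggests, that the downstream `padding` annotation is `List[int]` while the transform accepts any `Sequence[int]`). `normalize_padding` is an illustrative helper, not part of the diff; the point is that the conversion happens once at construction instead of on every `_transform` call:

```python
from typing import List, Sequence, Union

def normalize_padding(padding: Union[int, Sequence[int]]) -> Union[int, List[int]]:
    # A tuple such as (1, 2, 3, 4) becomes [1, 2, 3, 4]; plain ints pass through.
    return padding if isinstance(padding, int) else list(padding)

assert normalize_padding((1, 2, 3, 4)) == [1, 2, 3, 4]
assert normalize_padding(2) == 2
```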
@@ -298,7 +295,7 @@ def __init__(
         self.center = center
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         return dict(angle=angle)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
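The dropped `float(...)` wrappers are redundant because `Tensor.item()` on a floating-point tensor already returns a built-in Python `float`, and `uniform_` accepts plain numbers for its bounds. A quick check (assuming `torch` is installed):

```python
import torch

# .item() converts a one-element float tensor to a Python float directly.
angle = torch.empty(1).uniform_(-30.0, 30.0).item()
assert isinstance(angle, float)
```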
@@ -355,7 +352,7 @@ def __init__(
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         height, width = query_spatial_size(flat_inputs)
 
-        angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         if self.translate is not None:
             max_dx = float(self.translate[0] * width)
             max_dy = float(self.translate[1] * height)
@@ -366,15 +363,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             translate = (0, 0)
 
         if self.scale is not None:
-            scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
+            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
         else:
             scale = 1.0
 
         shear_x = shear_y = 0.0
         if self.shear is not None:
-            shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
+            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
             if len(self.shear) == 4:
-                shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
+                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
 
         shear = (shear_x, shear_y)
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
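A self-contained sketch of the shear convention this hunk preserves: a 2-element `shear` range samples only the x-shear, while a 4-element range also samples the y-shear. `sample_shear` is an illustrative name, not part of the diff:

```python
import torch

def sample_shear(shear):
    shear_x = shear_y = 0.0
    if shear is not None:
        shear_x = torch.empty(1).uniform_(shear[0], shear[1]).item()
        if len(shear) == 4:
            shear_y = torch.empty(1).uniform_(shear[2], shear[3]).item()
    return (shear_x, shear_y)

print(sample_shear((-10.0, 10.0)))             # shear_y stays 0.0
print(sample_shear((-10.0, 10.0, -5.0, 5.0)))  # both components sampled
```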
@@ -451,12 +448,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         needs_pad = any(padding)
 
         needs_vert_crop, top = (
-            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
+            (True, torch.randint(0, padded_height - cropped_height + 1, size=()).item())
             if padded_height > cropped_height
             else (False, 0)
         )
         needs_horz_crop, left = (
-            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
+            (True, torch.randint(0, padded_width - cropped_width + 1, size=()).item())
             if padded_width > cropped_width
             else (False, 0)
         )
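Likewise for integers: `Tensor.item()` on an integer tensor returns a Python `int`, which is why the `int(...)` wrapper can go. The `+ 1` keeps the maximal crop offset reachable, since `torch.randint`'s upper bound is exclusive. A small sketch with made-up sizes:

```python
import torch

# With padded_height=10 and cropped_height=6, valid top offsets are 0..4.
top = torch.randint(0, 10 - 6 + 1, size=()).item()
assert isinstance(top, int) and 0 <= top <= 4
```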
@@ -506,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         half_height = height // 2
         half_width = width // 2
+        bound_height = int(distortion_scale * half_height) + 1
+        bound_width = int(distortion_scale * half_width) + 1
         topleft = [
-            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
-            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+            torch.randint(0, bound_width, size=(1,)).item(),
+            torch.randint(0, bound_height, size=(1,)).item(),
         ]
         topright = [
-            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
-            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+            torch.randint(width - bound_width, width, size=(1,)).item(),
+            torch.randint(0, bound_height, size=(1,)).item(),
         ]
         botright = [
-            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
-            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+            torch.randint(width - bound_width, width, size=(1,)).item(),
+            torch.randint(height - bound_height, height, size=(1,)).item(),
         ]
         botleft = [
-            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
-            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+            torch.randint(0, bound_width, size=(1,)).item(),
+            torch.randint(height - bound_height, height, size=(1,)).item(),
         ]
         startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
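A standalone sketch (with made-up sizes) of the hoisted bounds: each corner of an `height x width` image is jittered inward by at most `distortion_scale` times the half-dimension, and precomputing `bound_height`/`bound_width` avoids repeating the `int(...) + 1` expression eight times. Note the bounds already include the `+ 1` for `torch.randint`'s exclusive upper limit:

```python
import torch

height, width, distortion_scale = 100, 200, 0.5
bound_height = int(distortion_scale * (height // 2)) + 1
bound_width = int(distortion_scale * (width // 2)) + 1

# Top-left corner is displaced into the image from (0, 0)...
topleft = [
    torch.randint(0, bound_width, size=(1,)).item(),
    torch.randint(0, bound_height, size=(1,)).item(),
]
# ...while the bottom-right corner is displaced back from (width, height).
botright = [
    torch.randint(width - bound_width, width, size=(1,)).item(),
    torch.randint(height - bound_height, height, size=(1,)).item(),
]
assert 0 <= topleft[0] <= distortion_scale * (width // 2)
```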
@@ -623,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         while True:
             # sample an option
-            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            idx = torch.randint(low=0, high=len(self.options), size=(1,)).item()
             min_jaccard_overlap = self.options[idx]
             if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                 return dict()
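A sketch of the sampling loop's first step, assuming `self.options` holds SSD-style minimum-IoU thresholds (the exact values below are illustrative, not taken from the diff): an index is drawn uniformly over the candidates, and any threshold at or above 1.0 selects the leave-as-is branch:

```python
import torch

options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 2.0]  # illustrative thresholds
idx = torch.randint(low=0, high=len(options), size=(1,)).item()  # plain Python int
min_jaccard_overlap = options[idx]
leave_as_is = min_jaccard_overlap >= 1.0
```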