@@ -700,40 +700,43 @@ def __init__(self, crop_to=None, crop_mode="c"):
700700 super (Crop , self ).__init__ (p = 1 , data_indices = None )
701701
702702 if crop_to is not None :
703- if not isinstance (crop_to , int ) and not isinstance ( crop_to , tuple ):
704- raise TypeError ("Argument crop_size has an incorrect type!" )
703+ if not isinstance (crop_to , ( int , tuple , list ) ):
704+ raise TypeError ("Argument crop_to has an incorrect type!" )
705705 if crop_mode not in ALLOWED_CROPS :
706706 raise ValueError ("Argument crop_mode has an incorrect type!" )
707707
708+ if isinstance (crop_to , list ):
709+ crop_to = tuple (crop_to )
710+
708711 if isinstance (crop_to , tuple ):
709712 if not isinstance (crop_to [0 ], int ) or not isinstance (crop_to [1 ], int ):
710- raise TypeError ("Incorrect type of the crop_size !" )
713+ raise TypeError ("Incorrect type of the crop_to !" )
711714
712715 if isinstance (crop_to , int ):
713716 crop_to = (crop_to , crop_to )
714717
715- self .crop_size = crop_to
718+ self .crop_to = crop_to
716719 self .crop_mode = crop_mode
717720
718721 def sample_transform (self , data : DataContainer ):
719722 h , w = super (Crop , self ).sample_transform (data )
720- if self .crop_size is not None :
721- if self .crop_size [0 ] > w or self .crop_size [1 ] > h :
723+ if self .crop_to is not None :
724+ if self .crop_to [0 ] > w or self .crop_to [1 ] > h :
722725 raise ValueError
723726
724727 if self .crop_mode == "r" :
725- self .state_dict ["x" ] = int (random .random () * (w - self .crop_size [0 ]))
726- self .state_dict ["y" ] = int (random .random () * (h - self .crop_size [1 ]))
728+ self .state_dict ["x" ] = int (random .random () * (w - self .crop_to [0 ]))
729+ self .state_dict ["y" ] = int (random .random () * (h - self .crop_to [1 ]))
727730
728731 else :
729- self .state_dict ["x" ] = w // 2 - self .crop_size [0 ] // 2
730- self .state_dict ["y" ] = h // 2 - self .crop_size [1 ] // 2
732+ self .state_dict ["x" ] = w // 2 - self .crop_to [0 ] // 2
733+ self .state_dict ["y" ] = h // 2 - self .crop_to [1 ] // 2
731734
732735 def __crop_img_or_mask (self , img_mask ):
733- if self .crop_size is not None :
736+ if self .crop_to is not None :
734737 return img_mask [
735- self .state_dict ["y" ] : self .state_dict ["y" ] + self .crop_size [1 ],
736- self .state_dict ["x" ] : self .state_dict ["x" ] + self .crop_size [0 ],
738+ self .state_dict ["y" ] : self .state_dict ["y" ] + self .crop_to [1 ],
739+ self .state_dict ["x" ] : self .state_dict ["x" ] + self .crop_to [0 ],
737740 ]
738741 return img_mask
739742
@@ -748,15 +751,15 @@ def _apply_labels(self, labels, settings: dict):
748751 return labels
749752
750753 def _apply_pts (self , pts : Keypoints , settings : dict ):
751- if self .crop_size is None :
754+ if self .crop_to is None :
752755 return pts
753756 pts_data = pts .data .copy ()
754757 x , y = self .state_dict ["x" ], self .state_dict ["y" ]
755758
756759 pts_data [:, 0 ] -= x
757760 pts_data [:, 1 ] -= y
758761
759- return Keypoints (pts_data , self .crop_size [1 ], self .crop_size [0 ])
762+ return Keypoints (pts_data , self .crop_to [1 ], self .crop_to [0 ])
760763
761764
762765class Noise (BaseTransform ):
0 commit comments