Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions solt/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,40 +700,43 @@ def __init__(self, crop_to=None, crop_mode="c"):
super(Crop, self).__init__(p=1, data_indices=None)

if crop_to is not None:
if not isinstance(crop_to, int) and not isinstance(crop_to, tuple):
raise TypeError("Argument crop_size has an incorrect type!")
if not isinstance(crop_to, (int, tuple, list)):
raise TypeError("Argument crop_to has an incorrect type!")
if crop_mode not in ALLOWED_CROPS:
raise ValueError("Argument crop_mode has an incorrect type!")

if isinstance(crop_to, list):
crop_to = tuple(crop_to)

if isinstance(crop_to, tuple):
if not isinstance(crop_to[0], int) or not isinstance(crop_to[1], int):
raise TypeError("Incorrect type of the crop_size!")
raise TypeError("Incorrect type of the crop_to!")

if isinstance(crop_to, int):
crop_to = (crop_to, crop_to)

self.crop_size = crop_to
self.crop_to = crop_to
self.crop_mode = crop_mode

def sample_transform(self, data: DataContainer):
h, w = super(Crop, self).sample_transform(data)
if self.crop_size is not None:
if self.crop_size[0] > w or self.crop_size[1] > h:
if self.crop_to is not None:
if self.crop_to[0] > w or self.crop_to[1] > h:
raise ValueError

if self.crop_mode == "r":
self.state_dict["x"] = int(random.random() * (w - self.crop_size[0]))
self.state_dict["y"] = int(random.random() * (h - self.crop_size[1]))
self.state_dict["x"] = int(random.random() * (w - self.crop_to[0]))
self.state_dict["y"] = int(random.random() * (h - self.crop_to[1]))

else:
self.state_dict["x"] = w // 2 - self.crop_size[0] // 2
self.state_dict["y"] = h // 2 - self.crop_size[1] // 2
self.state_dict["x"] = w // 2 - self.crop_to[0] // 2
self.state_dict["y"] = h // 2 - self.crop_to[1] // 2

def __crop_img_or_mask(self, img_mask):
if self.crop_size is not None:
if self.crop_to is not None:
return img_mask[
self.state_dict["y"] : self.state_dict["y"] + self.crop_size[1],
self.state_dict["x"] : self.state_dict["x"] + self.crop_size[0],
self.state_dict["y"] : self.state_dict["y"] + self.crop_to[1],
self.state_dict["x"] : self.state_dict["x"] + self.crop_to[0],
]
return img_mask

Expand All @@ -748,15 +751,15 @@ def _apply_labels(self, labels, settings: dict):
return labels

def _apply_pts(self, pts: Keypoints, settings: dict):
if self.crop_size is None:
if self.crop_to is None:
return pts
pts_data = pts.data.copy()
x, y = self.state_dict["x"], self.state_dict["y"]

pts_data[:, 0] -= x
pts_data[:, 1] -= y

return Keypoints(pts_data, self.crop_size[1], self.crop_size[0])
return Keypoints(pts_data, self.crop_to[1], self.crop_to[0])


class Noise(BaseTransform):
Expand Down