Skip to content

Commit 33b2149

Browse files
authored
Merge pull request #62 from imelekhov/master
bugfix: incorrect serialization (cropping)
2 parents 770e397 + 3ef7cf4 commit 33b2149

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

solt/transforms/_transforms.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -700,40 +700,43 @@ def __init__(self, crop_to=None, crop_mode="c"):
700700
super(Crop, self).__init__(p=1, data_indices=None)
701701

702702
if crop_to is not None:
703-
if not isinstance(crop_to, int) and not isinstance(crop_to, tuple):
704-
raise TypeError("Argument crop_size has an incorrect type!")
703+
if not isinstance(crop_to, (int, tuple, list)):
704+
raise TypeError("Argument crop_to has an incorrect type!")
705705
if crop_mode not in ALLOWED_CROPS:
706706
raise ValueError("Argument crop_mode has an incorrect type!")
707707

708+
if isinstance(crop_to, list):
709+
crop_to = tuple(crop_to)
710+
708711
if isinstance(crop_to, tuple):
709712
if not isinstance(crop_to[0], int) or not isinstance(crop_to[1], int):
710-
raise TypeError("Incorrect type of the crop_size!")
713+
raise TypeError("Incorrect type of the crop_to!")
711714

712715
if isinstance(crop_to, int):
713716
crop_to = (crop_to, crop_to)
714717

715-
self.crop_size = crop_to
718+
self.crop_to = crop_to
716719
self.crop_mode = crop_mode
717720

718721
def sample_transform(self, data: DataContainer):
719722
h, w = super(Crop, self).sample_transform(data)
720-
if self.crop_size is not None:
721-
if self.crop_size[0] > w or self.crop_size[1] > h:
723+
if self.crop_to is not None:
724+
if self.crop_to[0] > w or self.crop_to[1] > h:
722725
raise ValueError
723726

724727
if self.crop_mode == "r":
725-
self.state_dict["x"] = int(random.random() * (w - self.crop_size[0]))
726-
self.state_dict["y"] = int(random.random() * (h - self.crop_size[1]))
728+
self.state_dict["x"] = int(random.random() * (w - self.crop_to[0]))
729+
self.state_dict["y"] = int(random.random() * (h - self.crop_to[1]))
727730

728731
else:
729-
self.state_dict["x"] = w // 2 - self.crop_size[0] // 2
730-
self.state_dict["y"] = h // 2 - self.crop_size[1] // 2
732+
self.state_dict["x"] = w // 2 - self.crop_to[0] // 2
733+
self.state_dict["y"] = h // 2 - self.crop_to[1] // 2
731734

732735
def __crop_img_or_mask(self, img_mask):
733-
if self.crop_size is not None:
736+
if self.crop_to is not None:
734737
return img_mask[
735-
self.state_dict["y"] : self.state_dict["y"] + self.crop_size[1],
736-
self.state_dict["x"] : self.state_dict["x"] + self.crop_size[0],
738+
self.state_dict["y"] : self.state_dict["y"] + self.crop_to[1],
739+
self.state_dict["x"] : self.state_dict["x"] + self.crop_to[0],
737740
]
738741
return img_mask
739742

@@ -748,15 +751,15 @@ def _apply_labels(self, labels, settings: dict):
748751
return labels
749752

750753
def _apply_pts(self, pts: Keypoints, settings: dict):
751-
if self.crop_size is None:
754+
if self.crop_to is None:
752755
return pts
753756
pts_data = pts.data.copy()
754757
x, y = self.state_dict["x"], self.state_dict["y"]
755758

756759
pts_data[:, 0] -= x
757760
pts_data[:, 1] -= y
758761

759-
return Keypoints(pts_data, self.crop_size[1], self.crop_size[0])
762+
return Keypoints(pts_data, self.crop_to[1], self.crop_to[0])
760763

761764

762765
class Noise(BaseTransform):

0 commit comments

Comments
 (0)