[prototype] Switch to spatial_size
#6736
datumbox
left a comment
Adding comments to places where they didn't happen automatically with the IDE:
| def get_num_channels_video(video: torch.Tensor) -> int: | ||
| return get_num_channels_image_tensor(video) |
Addition of get_num_channels_video kernel.
| def get_spatial_size_video(video: torch.Tensor) -> List[int]: | ||
| return get_spatial_size_image_tensor(video) | ||
| def get_spatial_size_mask(mask: torch.Tensor) -> List[int]: | ||
| return get_spatial_size_image_tensor(mask) | ||
| @torch.jit.unused | ||
| def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]: | ||
| return list(bounding_box.spatial_size) |
Addition of the get_spatial_size_* kernels. The BBox one can't have a JIT-scriptable implementation, as it relies on tensor subclassing to retrieve this info.
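A minimal sketch of why the bounding-box kernel is the odd one out, using plain-Python stand-ins rather than the real torch/torchvision types (`FakeTensor`, `FakeBoundingBox` and both helper functions are hypothetical names). It assumes images, videos and masks are laid out as `(..., H, W)`, so their size can be read from the shape, while a box tensor's shape only describes coordinates and the size has to travel as metadata on the subclass:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a plain tensor: it only knows its shape."""

    shape: Tuple[int, ...]


@dataclass
class FakeBoundingBox(FakeTensor):
    """Hypothetical stand-in for a box feature: shape is (N, 4), size is metadata."""

    spatial_size: Tuple[int, int] = (0, 0)


def spatial_size_from_shape(t: FakeTensor) -> List[int]:
    # Assumption: the last two dimensions are height and width.
    return list(t.shape[-2:])


def spatial_size_from_metadata(box: FakeBoundingBox) -> List[int]:
    # The (N, 4) shape holds coordinates, so the size comes from the attribute.
    return list(box.spatial_size)


video = FakeTensor(shape=(8, 3, 240, 320))
boxes = FakeBoundingBox(shape=(5, 4), spatial_size=(240, 320))
```

Reading the shape of `boxes` would give `[5, 4]`, which is why the box kernel cannot simply delegate to the image-tensor kernel the way the video and mask kernels do.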
| elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)): | ||
| return list(inpt.spatial_size) | ||
| else: | ||
| - return get_spatial_size_image_pil(inpt) | ||
| + return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return] |
Refactoring to avoid the getattr idiom. After that, mypy complains about the PIL kernel. It's unclear to me why it thinks we return Any. The get_spatial_size_video returns a List[int].
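For illustration, the two dispatch styles can be contrasted with stand-in classes. The old getattr-based code isn't shown in this thread, so `size_via_getattr` is only a guess at the general idiom, and `Feature` and `PilLike` are hypothetical substitutes for the feature classes and `PIL.Image.Image`:

```python
from typing import Any, List, Tuple


class Feature:
    """Hypothetical stand-in for features.Image / Video / Mask / BoundingBox."""

    def __init__(self, spatial_size: Tuple[int, int]) -> None:
        self.spatial_size = spatial_size


class PilLike:
    """Hypothetical stand-in for PIL.Image.Image, whose size is (width, height)."""

    def __init__(self, size: Tuple[int, int]) -> None:
        self.size = size


def size_via_getattr(inpt: Any) -> List[int]:
    # Attribute probing: any object with a `spatial_size` attribute matches.
    size = getattr(inpt, "spatial_size", None)
    if size is not None:
        return list(size)
    width, height = inpt.size
    return [height, width]


def size_via_isinstance(inpt: Any) -> List[int]:
    # Explicit type check: the branch is tied to a known set of classes.
    if isinstance(inpt, Feature):
        return list(inpt.spatial_size)
    width, height = inpt.size
    return [height, width]
```

Both versions return `[height, width]` for the same inputs. The difference is that the `isinstance` branch names the accepted types explicitly, which is what the diff above does with the `features.*` classes.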
| lam = float(self._dist.sample(())) | ||
| - _, H, W = query_chw(sample) | ||
| + H, W = query_hw(sample) |
Removing the query_chw in favour of query_hw where possible. This happens in multiple places in the code-base.
The method was renamed to query_spatial_size after #6736 (review)
| def _get_params(self, sample: Any) -> Dict[str, Any]: | ||
| - num_channels, _, _ = query_chw(sample) | ||
| + num_channels, *_ = query_chw(sample) |
I didn't introduce yet another method for extracting channels only. This is indeed less elegant but doesn't introduce any limitations as the input is required to have channels. This happens in one more place in the codebase.
| def query_hw(sample: Any) -> Tuple[int, int]: | ||
| flat_sample, _ = tree_flatten(sample) | ||
| hws = { | ||
| tuple(get_spatial_size(item)) | ||
| for item in flat_sample | ||
| if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox)) | ||
| or features.is_simple_tensor(item) | ||
| } | ||
| if not hws: | ||
| raise TypeError("No image, video, mask or bounding box was found in the sample") | ||
| elif len(hws) > 1: | ||
| raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(hws))}") | ||
| h, w = hws.pop() | ||
| return h, w |
Lots of code duplication with query_chw. The two methods differ in the callable, the checked types, the error messages and the return type. I was tempted to write something that passes a callable and tries to reduce duplicate code, but it became unnecessarily complex. Happy to implement other approaches if you have better ideas.
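For reference, the callable-passing alternative mentioned here could look roughly like the sketch below. It is not code from this PR: `query_unique`, `Img` and the two wrappers are hypothetical, the sample is assumed to be already flattened, and the error messages are simplified:

```python
from typing import Any, Callable, Iterable, Tuple, Type


def query_unique(
    flat_sample: Iterable[Any],
    extract: Callable[[Any], Tuple[int, ...]],
    types: Tuple[Type[Any], ...],
    what: str,
) -> Tuple[int, ...]:
    # Collect the distinct values that `extract` yields for the matching items.
    values = {extract(item) for item in flat_sample if isinstance(item, types)}
    if not values:
        raise TypeError(f"No supported input was found in the sample when querying {what}")
    if len(values) > 1:
        raise ValueError(f"Found multiple {what} in the sample: {sorted(values)}")
    return values.pop()


class Img:
    """Hypothetical stand-in for an image with a (C, H, W) shape."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


def query_hw(flat_sample: Iterable[Any]) -> Tuple[int, ...]:
    return query_unique(flat_sample, lambda x: tuple(x.shape[-2:]), (Img,), "HxW dimensions")


def query_chw(flat_sample: Iterable[Any]) -> Tuple[int, ...]:
    return query_unique(flat_sample, lambda x: tuple(x.shape[-3:]), (Img,), "CxHxW dimensions")
```

The shared helper needs four parameters and loses the precise `Tuple[int, int]` return type, which gives a sense of the complexity trade-off described in the comment.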
I got a few failures on Windows. They don't look related at first glance, but this PR touches a lot of things, so I'm not 100% sure. I'll check again tomorrow.
vfdev-5
left a comment
OK to me, thanks @datumbox !
Just a minor suggestion to rename query_hw to query_spatial_size... (not blocking)
Summary:
* Change `image_size` to `spatial_size`
* Fix linter
* Fixing more tests.
* Adding get_num_channels_video and get_spatial_size_* kernels for video, masks and bboxes.
* Refactor get_spatial_size
* Reduce the usage of `query_chw` where possible
* Rename `query_chw` to `query_spatial_size`
* Adding `get_num_frames` dispatcher and kernel.
* Adding jit-scriptability tests

Reviewed By: NicolasHug
Differential Revision: D40427485
fbshipit-source-id: 2401fe20877177459fe23181655c9cf429cb0cc5
This PR:
* Changes `image_size` to `spatial_size` everywhere in the code-base
* Adds `get_num_channels_video` and `get_spatial_size_*` kernels for videos, masks and bboxes
* Adds a `get_num_frames` dispatcher and a `get_num_frames_video` kernel
* Adds tests for JIT.
* Reduces the usage of `query_chw` to make things work with bboxes and masks