Add deterministic, pure-Python roi_align implementation #7587
Changes from 4 commits
@@ -1,16 +1,149 @@
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from torchvision.extension import _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


# NB: all tensor inputs
def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask):
    # deal with inverse element out of feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(y, x):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask, y, 0)
            x = torch.where(xmask, x, 0)
        return input[roi_batch_ind, c, y, x]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)
    w1 = hy * hx
    w2 = hy * lx
    w3 = ly * hx
    w4 = ly * lx

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val

# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor


# This is a slow but pure Python and differentiable implementation of
# roi_align. It is potentially a good basis for Inductor compilation
# (but I have not benchmarked it); today it is used solely because its
# backward can be implemented deterministically.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    from functorch.dim import dims

    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    n, c, ph, pw = dims(4)
    ph.size = pooled_height
    pw.size = pooled_width
    offset_rois = rois[n]
    roi_batch_ind = offset_rois[0].int()
    offset = 0.5 if aligned else 0.0
    roi_start_w = offset_rois[1] * spatial_scale - offset
    roi_start_h = offset_rois[2] * spatial_scale - offset
    roi_end_w = offset_rois[3] * spatial_scale - offset
    roi_end_h = offset_rois[4] * spatial_scale - offset

    roi_width = roi_end_w - roi_start_w
    roi_height = roi_end_h - roi_start_h
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)
        roi_height = torch.clamp(roi_height, min=1.0)

    bin_size_h = roi_height / pooled_height
    bin_size_w = roi_width / pooled_width

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)

    iy, ix = dims(2)

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
        iy.size = roi_bin_grid_h
        ix.size = roi_bin_grid_w
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)
        # When doing adaptive sampling, the number of samples we need is
        # data-dependent, based on how big the ROIs are. This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy.size = height
        ix.size = width
        ymask = iy < roi_bin_grid_h
        xmask = ix < roi_bin_grid_w

    y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h
    x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w
    val = _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask)

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask, val, 0)
        val = torch.where(xmask, val, 0)

    output = val.sum((iy, ix))
    output /= count

    output = output.to(orig_dtype)

    return output.order(n, c, ph, pw)


@torch.fx.wrap
def roi_align(
    input: Tensor,

@@ -54,12 +187,14 @@ def roi_align(
| """ | ||
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | ||
| _log_api_usage_once(roi_align) | ||
| _assert_has_ops() | ||
| check_roi_boxes_shape(boxes) | ||
| rois = boxes | ||
| output_size = _pair(output_size) | ||
| if not isinstance(rois, torch.Tensor): | ||
| rois = convert_boxes_to_roi_format(rois) | ||
| if not torch.jit.is_scripting(): | ||
| if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): | ||
| return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) | ||
ezyang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return torch.ops.torchvision.roi_align( | ||
| input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned | ||
| ) | ||
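
To exercise the new fallback: per the dispatch condition above, enabling deterministic algorithms on a CUDA input routes `roi_align` through the pure-Python `_roi_align`, whose backward avoids the nondeterministic atomic adds of the CUDA kernel. A minimal sketch, assuming a CUDA machine with this patch applied:

```python
import torch
import torchvision

torch.use_deterministic_algorithms(True)

feats = torch.randn(2, 256, 32, 32, device="cuda", requires_grad=True)
# Each roi row is (batch_index, x1, y1, x2, y2)
rois = torch.tensor([[0.0, 4.0, 4.0, 28.0, 28.0],
                     [1.0, 0.0, 0.0, 16.0, 16.0]], device="cuda")

def grads():
    feats.grad = None
    out = torchvision.ops.roi_align(
        feats, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2
    )
    out.sum().backward()
    return feats.grad.clone()

# The pure-Python backward is deterministic, so repeated runs should
# produce bitwise-identical gradients.
assert torch.equal(grads(), grads())
```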
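The implementation leans on first-class dimensions from `functorch.dim`: indexing binds tensor axes to dim objects, reductions can name the dims they reduce over (as in `val.sum((iy, ix))`), and `.order(...)` turns named dims back into positional ones. A tiny standalone sketch of that mechanism, assuming a PyTorch build that ships functorch's first-class dims:

```python
import torch
from functorch.dim import dims

A = torch.arange(6).reshape(2, 3)

i, j = dims(2)
Aij = A[i, j]               # bind positional axes to first-class dims i, j

# Elementwise ops broadcast over named dims; .order() makes them positional
B = (Aij * 10).order(j, i)  # requesting (j, i) also transposes
assert B.shape == (3, 2)

# Reductions can take the dims to reduce over, like val.sum((iy, ix)) above
s = Aij.sum(j).order(i)
assert torch.equal(s, A.sum(1))
```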
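`_bilinear_interpolate` applies the standard four-corner weighting, whose weights `w1..w4` always sum to 1. A hand-checked instance of the same formula in plain PyTorch (the helper itself expects first-class-dim inputs, so this recomputes it directly):

```python
import torch

img = torch.tensor([[1.0, 2.0],
                    [3.0, 5.0]])  # 2x2 feature map
y, x = 0.25, 0.75                 # sample point

y_low, x_low = int(y), int(x)     # 0, 0
y_high, x_high = y_low + 1, x_low + 1
ly, lx = y - y_low, x - x_low     # 0.25, 0.75
hy, hx = 1.0 - ly, 1.0 - lx

# The four weights always sum to 1
w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx
val = (w1 * img[y_low, x_low] + w2 * img[y_low, x_high]
       + w3 * img[y_high, x_low] + w4 * img[y_high, x_high])

# Cross-check against the closed form:
# 0.1875*1 + 0.5625*2 + 0.0625*3 + 0.1875*5 = 2.4375
assert abs(val.item() - 2.4375) < 1e-6
```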
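The adaptive-sampling branch oversizes the sample grid to a static shape and masks out the unneeded points, since first-class dims cannot have data-dependent sizes. The same pad-and-mask pattern in ordinary PyTorch, as a hypothetical sketch (the `needed` counts are invented for illustration):

```python
import torch

# Per-ROI sample counts are data-dependent; pretend ROI 0 needs 2 samples
# and ROI 1 needs 4, but we allocate the max grid (here 4) for both.
max_samples = 4
needed = torch.tensor([[2], [4]])         # shape [K, 1]
idx = torch.arange(max_samples)           # shape [S]
mask = idx < needed                       # shape [K, S]

samples = torch.ones(2, max_samples)      # stand-in for interpolated values
samples = torch.where(mask, samples, 0.0) # zero out the padding samples

# Divide by the true count, not the padded size, mirroring `output /= count`
avg = samples.sum(dim=1) / needed.squeeze(1)
assert torch.equal(avg, torch.ones(2))
```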