
Commit fc838ad

Add deterministic, pure-Python roi_align implementation (#7587)
Signed-off-by: Edward Z. Yang <[email protected]>
1 parent a557918 commit fc838ad
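With this change, roi_align dispatches to a pure-Python _roi_align implementation when the torchvision C++ ops are unavailable, or when deterministic algorithms are enabled and the input is on CUDA. A minimal sketch of how that path is exercised (not part of the commit; assumes a CUDA build of PyTorch and torchvision):

import torch
from torchvision.ops import roi_align

x = torch.rand(1, 3, 32, 32, device="cuda", requires_grad=True)
# rois as a [K, 5] tensor: (batch_index, x1, y1, x2, y2)
boxes = torch.tensor([[0.0, 4.0, 4.0, 28.0, 28.0]], device="cuda")

# Enabling deterministic algorithms routes the CUDA call through the new
# pure-Python fallback, whose backward pass is deterministic.
torch.use_deterministic_algorithms(True)
out = roi_align(x, boxes, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)
out.sum().backward()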

File tree

test/test_ops.py
torchvision/ops/roi_align.py

2 files changed: +227 −13 lines changed


test/test_ops.py

Lines changed: 50 additions & 11 deletions
@@ -19,6 +19,22 @@
 from torchvision.models.feature_extraction import get_graph_node_names


+# Context manager for setting deterministic flag and automatically
+# resetting it to its original value
+class DeterministicGuard:
+    def __init__(self, deterministic, *, warn_only=False):
+        self.deterministic = deterministic
+        self.warn_only = warn_only
+
+    def __enter__(self):
+        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
+        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
+        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
+
+
 class RoIOpTesterModuleWrapper(nn.Module):
     def __init__(self, obj):
         super().__init__()
@@ -83,7 +99,7 @@ class RoIOpTester(ABC):

     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
+    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
         x_dtype = self.dtype if x_dtype is None else x_dtype
         rois_dtype = self.dtype if rois_dtype is None else rois_dtype
         pool_size = 5
@@ -99,7 +115,8 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
         )

         pool_h, pool_w = pool_size, pool_size
-        y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
+        with DeterministicGuard(deterministic):
+            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
@@ -140,7 +157,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_backward(self, seed, device, contiguous):
+    def test_backward(self, seed, device, contiguous, deterministic=False):
         torch.random.manual_seed(seed)
         pool_size = 2
         x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
@@ -155,7 +172,9 @@ def func(z):

         script_func = self.get_script_fn(rois, pool_size)

-        gradcheck(func, (x,))
+        with DeterministicGuard(deterministic):
+            gradcheck(func, (x,))
+
         gradcheck(script_func, (x,))

     @needs_cuda
@@ -384,7 +403,6 @@ def expected_fn(
                     grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

                     for channel in range(0, n_channels):
-
                         val = 0
                         for iy in range(0, grid_h):
                             y = start_h + (iy + 0.5) * bin_h / grid_h
@@ -402,21 +420,44 @@ def test_boxes_shape(self):
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
+    @pytest.mark.parametrize("deterministic", (True, False))
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+        if deterministic and device == "cpu":
+            pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
-            device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
+            device=device,
+            contiguous=contiguous,
+            deterministic=deterministic,
+            x_dtype=x_dtype,
+            rois_dtype=rois_dtype,
+            aligned=aligned,
         )

     @needs_cuda
     @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
     @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, aligned, x_dtype, rois_dtype):
+    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
         with torch.cuda.amp.autocast():
             self.test_forward(
-                torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
+                torch.device("cuda"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
             )

+    @pytest.mark.parametrize("seed", range(10))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("contiguous", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    def test_backward(self, seed, device, contiguous, deterministic):
+        if deterministic and device == "cpu":
+            pytest.skip("cpu is always deterministic, don't retest")
+        super().test_backward(seed, device, contiguous, deterministic)
+
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
         rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
         rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
@@ -978,7 +1019,6 @@ def test_compare_cpu_cuda_grads(self, contiguous):
             weight = init_weight

         for d in ["cpu", "cuda"]:
-
             out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
             out.mean().backward()
             if true_cpu_grads is None:
@@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
     def test_giou_loss(self, dtype, device):
-
         box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

         # Identical boxes should have loss of 0
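DeterministicGuard saves the global determinism flags on entry and restores them on exit, so an individual test can toggle torch.use_deterministic_algorithms without leaking state into other tests. A small usage sketch, assuming the flag starts out disabled and that test/test_ops.py is importable as test_ops:

import torch
from test_ops import DeterministicGuard  # assumes test/ is on the import path

assert not torch.are_deterministic_algorithms_enabled()  # assumed initial state
with DeterministicGuard(True, warn_only=True):
    # inside the guard, both flags reflect the requested settings
    assert torch.are_deterministic_algorithms_enabled()
    assert torch.is_deterministic_algorithms_warn_only_enabled()
# on exit, both flags are restored to the values saved in __enter__
assert not torch.are_deterministic_algorithms_enabled()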

torchvision/ops/roi_align.py

Lines changed: 177 additions & 2 deletions
@@ -1,16 +1,188 @@
 from typing import List, Union

 import torch
+import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
-from torchvision.extension import _assert_has_ops
+from torchvision.extension import _assert_has_ops, _has_ops

 from ..utils import _log_api_usage_once
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


+# NB: all inputs are tensors
+def _bilinear_interpolate(
+    input,  # [N, C, H, W]
+    roi_batch_ind,  # [K]
+    y,  # [K, PH, IY]
+    x,  # [K, PW, IX]
+    ymask,  # [K, IY]
+    xmask,  # [K, IX]
+):
+    _, channels, height, width = input.size()
+
+    # deal with inverse element out of feature map boundary
+    y = y.clamp(min=0)
+    x = x.clamp(min=0)
+    y_low = y.int()
+    x_low = x.int()
+    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
+    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
+    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
+
+    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
+    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
+    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
+
+    ly = y - y_low
+    lx = x - x_low
+    hy = 1.0 - ly
+    hx = 1.0 - lx
+
+    # do bilinear interpolation, but respect the masking!
+    # TODO: It's possible the masking here is unnecessary if y and
+    # x were clamped appropriately; hard to tell
+    def masked_index(
+        y,  # [K, PH, IY]
+        x,  # [K, PW, IX]
+    ):
+        if ymask is not None:
+            assert xmask is not None
+            y = torch.where(ymask[:, None, :], y, 0)
+            x = torch.where(xmask[:, None, :], x, 0)
+        return input[
+            roi_batch_ind[:, None, None, None, None, None],
+            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
+            y[:, None, :, None, :, None],  # prev [K, PH, IY]
+            x[:, None, None, :, None, :],  # prev [K, PW, IX]
+        ]  # [K, C, PH, PW, IY, IX]
+
+    v1 = masked_index(y_low, x_low)
+    v2 = masked_index(y_low, x_high)
+    v3 = masked_index(y_high, x_low)
+    v4 = masked_index(y_high, x_high)
+
+    # all ws preemptively [K, C, PH, PW, IY, IX]
+    def outer_prod(y, x):
+        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
+
+    w1 = outer_prod(hy, hx)
+    w2 = outer_prod(hy, lx)
+    w3 = outer_prod(ly, hx)
+    w4 = outer_prod(ly, lx)
+
+    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+    return val
+
+
+# TODO: this doesn't actually cache
+# TODO: main library should make this easier to do
+def maybe_cast(tensor):
+    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
+        return tensor.float()
+    else:
+        return tensor
+
+
+# This is a slow but pure Python and differentiable implementation of
+# roi_align. It potentially is a good basis for Inductor compilation
+# (but I have not benchmarked it) but today it is solely used for the
+# fact that its backwards can be implemented deterministically,
+# which is needed for the PT2 benchmark suite.
+#
+# It is transcribed directly off of the roi_align CUDA kernel, see
+# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
+@torch._dynamo.allow_in_graph
+def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    orig_dtype = input.dtype
+
+    input = maybe_cast(input)
+    rois = maybe_cast(rois)
+
+    _, _, height, width = input.size()
+
+    ph = torch.arange(pooled_height, device=input.device)  # [PH]
+    pw = torch.arange(pooled_width, device=input.device)  # [PW]
+
+    # input: [N, C, H, W]
+    # rois: [K, 5]
+
+    roi_batch_ind = rois[:, 0].int()  # [K]
+    offset = 0.5 if aligned else 0.0
+    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
+    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
+    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
+    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]
+
+    roi_width = roi_end_w - roi_start_w  # [K]
+    roi_height = roi_end_h - roi_start_h  # [K]
+    if not aligned:
+        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
+        roi_height = torch.clamp(roi_height, min=1.0)  # [K]
+
+    bin_size_h = roi_height / pooled_height  # [K]
+    bin_size_w = roi_width / pooled_width  # [K]
+
+    exact_sampling = sampling_ratio > 0
+
+    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
+    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]
+
+    """
+    iy, ix = dims(2)
+    """
+
+    if exact_sampling:
+        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
+        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
+        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
+        ymask = None
+        xmask = None
+    else:
+        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
+        # When doing adaptive sampling, the number of samples we need to do
+        # is data-dependent based on how big the ROIs are. This is a bit
+        # awkward because first-class dims can't actually handle this.
+        # So instead, we inefficiently suppose that we needed to sample ALL
+        # the points and mask out things that turned out to be unnecessary
+        iy = torch.arange(height, device=input.device)  # [IY]
+        ix = torch.arange(width, device=input.device)  # [IX]
+        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
+        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]
+
+    def from_K(t):
+        return t[:, None, None]
+
+    y = (
+        from_K(roi_start_h)
+        + ph[None, :, None] * from_K(bin_size_h)
+        + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
+    )  # [K, PH, IY]
+    x = (
+        from_K(roi_start_w)
+        + pw[None, :, None] * from_K(bin_size_w)
+        + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
+    )  # [K, PW, IX]
+    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
+
+    # Mask out samples that weren't actually adaptively needed
+    if not exact_sampling:
+        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
+        val = torch.where(xmask[:, None, None, None, None, :], val, 0)
+
+    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
+    if isinstance(count, torch.Tensor):
+        output /= count[:, None, None, None]
+    else:
+        output /= count
+
+    output = output.to(orig_dtype)
+
+    return output
+
+
 @torch.fx.wrap
 def roi_align(
     input: Tensor,
@@ -54,12 +226,15 @@ def roi_align(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(roi_align)
-    _assert_has_ops()
     check_roi_boxes_shape(boxes)
     rois = boxes
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
+    if not torch.jit.is_scripting():
+        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
+            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
+    _assert_has_ops()
     return torch.ops.torchvision.roi_align(
         input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
     )
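Because the fallback is meant to change only how the result is computed, a quick sanity check is to compare the private _roi_align against the public op on the same input. This is a sketch, not part of the commit, and it assumes a torchvision build where the compiled kernel is available so that the public call does not itself fall back:

import torch
from torchvision.ops import roi_align
from torchvision.ops.roi_align import _roi_align

x = torch.rand(1, 3, 10, 10, dtype=torch.float64)
rois = torch.tensor([[0.0, 1.0, 1.0, 8.0, 8.0]], dtype=torch.float64)  # (batch_index, x1, y1, x2, y2)

ref = roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=2, aligned=True)  # compiled kernel
alt = _roi_align(x, rois, 1.0, 5, 5, 2, True)  # pure-Python path
torch.testing.assert_close(ref, alt, rtol=1e-6, atol=1e-6)

On CUDA the same comparison applies, and the new deterministic=True parametrization in test_ops.py runs the existing RoIAlign forward and backward tests through this pure-Python path.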
