Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@
from torchvision.models.feature_extraction import get_graph_node_names


# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Context manager that temporarily switches torch's deterministic-algorithms mode.

    On entry the current global settings are captured and replaced by the requested ones;
    on exit (normal or through an exception) the captured settings are put back.

    Args:
        deterministic (bool): value passed to ``torch.use_deterministic_algorithms`` on entry.
        warn_only (bool, keyword-only): forwarded as ``warn_only`` to the same call. Default: False.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        # Snapshot the global flags *before* changing them so __exit__ can reinstate them.
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Hand back the guard so ``with DeterministicGuard(...) as guard:`` binds something useful.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None on purpose: an exception raised inside the ``with`` body keeps propagating.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


class RoIOpTesterModuleWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
Expand Down Expand Up @@ -83,7 +99,7 @@ class RoIOpTester(ABC):

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5
Expand All @@ -99,7 +115,8 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
)

pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
with DeterministicGuard(deterministic):
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
# the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype
gt_y = self.expected_fn(
Expand Down Expand Up @@ -140,7 +157,8 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous):
@pytest.mark.parametrize("deterministic", (False,))
def test_backward(self, seed, device, contiguous, deterministic):
torch.random.manual_seed(seed)
pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
Expand All @@ -155,7 +173,9 @@ def func(z):

script_func = self.get_script_fn(rois, pool_size)

gradcheck(func, (x,))
with DeterministicGuard(deterministic):
gradcheck(func, (x,))

gradcheck(script_func, (x,))

@needs_cuda
Expand Down Expand Up @@ -384,7 +404,6 @@ def expected_fn(
grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

for channel in range(0, n_channels):

val = 0
for iy in range(0, grid_h):
y = start_h + (iy + 0.5) * bin_h / grid_h
Expand All @@ -402,21 +421,42 @@ def test_boxes_shape(self):
@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
@pytest.mark.parametrize("deterministic", (True, False))
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_forward(
device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
device=device,
contiguous=contiguous,
deterministic=deterministic,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
aligned=aligned,
)

@needs_cuda
@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
def test_autocast(self, aligned, x_dtype, rois_dtype):
def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
with torch.cuda.amp.autocast():
self.test_forward(
torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
torch.device("cuda"),
contiguous=False,
deterministic=deterministic,
aligned=aligned,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_backward(self, seed, device, contiguous, deterministic):
    # Re-declared only to widen the "deterministic" parametrization to both values;
    # the actual check lives in the parent implementation.
    super().test_backward(
        seed=seed,
        device=device,
        contiguous=contiguous,
        deterministic=deterministic,
    )

def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
Expand Down Expand Up @@ -978,7 +1018,6 @@ def test_compare_cpu_cuda_grads(self, contiguous):
weight = init_weight

for d in ["cpu", "cuda"]:

out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward()
if true_cpu_grads is None:
Expand Down Expand Up @@ -1374,7 +1413,6 @@ class TestGeneralizedBoxIouLoss:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_giou_loss(self, dtype, device):

box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

# Identical boxes should have loss of 0
Expand Down
139 changes: 137 additions & 2 deletions torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,149 @@
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from torchvision.extension import _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


# NB: all tensor inputs
def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask):
# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx

# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(y, x):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask, y, 0)
x = torch.where(xmask, x, 0)
return input[roi_batch_ind, c, y, x]

v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
w1 = hy * hx
w2 = hy * lx
w3 = ly * hx
w4 = ly * lx

val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    """Return ``tensor`` as float32 when autocast is active and it is a non-double CUDA tensor.

    In every other situation the input object itself is returned, untouched.
    """
    needs_upcast = torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double
    return tensor.float() if needs_upcast else tensor


# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
# Pure-Python RoI Align written with functorch first-class dims: for every ROI (n),
# channel (c) and output bin (ph, pw) it averages bilinear samples of `input` taken on
# a grid of points inside that bin.
#
# Args (as used below):
#   input: 4-D tensor; its last two sizes are read as (height, width).
#   rois: indexed as rois[n][k]; column 0 is the batch index, columns 1..4 are read as
#       start_w, start_h, end_w, end_h before scaling.
#   spatial_scale: multiplier applied to the four box edges.
#   pooled_height, pooled_width: sizes bound to the ph / pw dims (the output bins).
#   sampling_ratio: > 0 means a fixed sampling_ratio x sampling_ratio grid per bin;
#       otherwise the grid is ceil(roi_size / pooled_size), computed per ROI.
#   aligned: when true the box edges are shifted by -0.5; when false the ROI width and
#       height are clamped to at least 1.0 instead.
#
# Returns the per-bin averages cast back to input's original dtype, ordered (n, c, ph, pw).
#
# The import is deliberately function-local: per the review thread below, importing
# functorch.dim monkey-patches torch, which should not happen by default.
from functorch.dim import dims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to lazy import? Isn't the functorch namespace always available now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is supposed to always be available but @zou3519 mentioned to me that xplat mumble mumble doesn't have working torchdims build mumble? In any case, there is not much harm in making it lazy like this, so I went ahead and did it this way.

Copy link
Contributor

@zou3519 zou3519 May 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better reason that I remembered just now is that import functorch.dim monkey-patches torch and we really do not want to monkey patch torch by default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For safety purposes, I hand-inserted all of the expands necessary to remove the first-class dims impl. We should still keep the first class dims version around for documentary purposes though.


# Remembered so the result can be cast back after the possible upcast in maybe_cast.
orig_dtype = input.dtype

input = maybe_cast(input)
rois = maybe_cast(rois)

_, _, height, width = input.size()

# n indexes ROIs, c indexes channels, ph/pw index output bin rows/columns. Only ph and pw
# are given an explicit size here; n and c are presumably bound when first used as an
# index -- TODO confirm against the first-class dims documentation.
n, c, ph, pw = dims(4)
ph.size = pooled_height
pw.size = pooled_width
offset_rois = rois[n]
roi_batch_ind = offset_rois[0].int()
# Half-pixel shift applied to every box edge only in "aligned" mode.
offset = 0.5 if aligned else 0.0
roi_start_w = offset_rois[1] * spatial_scale - offset
roi_start_h = offset_rois[2] * spatial_scale - offset
roi_end_w = offset_rois[3] * spatial_scale - offset
roi_end_h = offset_rois[4] * spatial_scale - offset

roi_width = roi_end_w - roi_start_w
roi_height = roi_end_h - roi_start_h
# Non-aligned mode forces every ROI to be at least 1x1.
if not aligned:
roi_width = torch.clamp(roi_width, min=1.0)
roi_height = torch.clamp(roi_height, min=1.0)

# Extent of one output bin, in input coordinates.
bin_size_h = roi_height / pooled_height
bin_size_w = roi_width / pooled_width

# A positive sampling_ratio fixes the number of samples per bin axis; otherwise it is
# derived per ROI from the bin size (and is therefore a tensor, not a Python number).
exact_sampling = sampling_ratio > 0

roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)
roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)

# iy/ix index the sample points inside one bin.
iy, ix = dims(2)

if exact_sampling:
# Number of samples per bin, floored at 1 so the final division is safe.
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
iy.size = roi_bin_grid_h
ix.size = roi_bin_grid_w
ymask = None
xmask = None
else:
count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)
# When doing adaptive sampling, the number of samples we need to do
# is data-dependent based on how big the ROIs are. This is a bit
# awkward because first-class dims can't actually handle this.
# So instead, we inefficiently suppose that we needed to sample ALL
# the points and mask out things that turned out to be unnecessary
iy.size = height
ix.size = width
ymask = iy < roi_bin_grid_h
xmask = ix < roi_bin_grid_w

# Sample coordinates: ROI start + bin offset + centre of sample cell (iy, ix) within the bin.
y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h
x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w
val = _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask)

# Mask out samples that weren't actually adaptively needed
if not exact_sampling:
val = torch.where(ymask, val, 0)
val = torch.where(xmask, val, 0)

# Average: sum over the sample dims, then divide by the number of samples per bin.
output = val.sum((iy, ix))
output /= count

output = output.to(orig_dtype)

# presumably turns the first-class dims back into positional dims in this order -- TODO confirm
return output.order(n, c, ph, pw)


@torch.fx.wrap
def roi_align(
input: Tensor,
Expand Down Expand Up @@ -54,12 +187,14 @@ def roi_align(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(roi_align)
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
if not torch.jit.is_scripting():
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
return torch.ops.torchvision.roi_align(
input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
)
Expand Down