setup.py (4 changes: 1 addition & 3 deletions)
@@ -131,7 +131,7 @@ def get_extensions():

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "ops", "*.cpp")
-    )
+    ) + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
@@ -184,8 +184,6 @@ def get_extensions():
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

-    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
-
    sources = main_file + source_cpu
    extension = CppExtension

test/test_ops.py (35 changes: 35 additions & 0 deletions)
@@ -122,6 +122,9 @@ def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, determinist
                tol = 5e-3
            else:
                tol = 4e-3
+
+        if x_dtype == torch.bfloat16:
+            tol = 5e-3

        pool_size = 5
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -493,6 +496,21 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
                rois_dtype=rois_dtype,
            )

+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
+        with torch.cpu.amp.autocast():
+            self.test_forward(
+                torch.device("cpu"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
+            )
+
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    @pytest.mark.parametrize("contiguous", (True, False))
@@ -751,6 +769,17 @@ def test_qnms(self, iou, scale, zero_point):

        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    def test_nms_cpu(self, iou, dtype=torch.float):
+        err_msg = "NMS incompatible between float and {} for IoU={}"
+
+        boxes, scores = self._create_tensors_with_iou(1000, iou)
+        r_ref = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
+        r_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
+
+        is_eq = torch.allclose(r_ref, r_dtype)
+        assert is_eq, err_msg.format(dtype, iou)
+
    @pytest.mark.parametrize(
        "device",
        (
@@ -782,6 +811,12 @@ def test_autocast(self, iou, dtype):
        with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, iou, dtype):
+        with torch.cpu.amp.autocast():
+            self.test_nms_cpu(iou=iou, dtype=dtype)
+
Member:
Would it be enough to just call test_nms_ref() here (perhaps with slight modifications)?

If we really need a specific test as in test_nms_cpu then maybe we can just inline it below instead of having test_nms_cpu as a standalone.

Contributor Author:
I modified test_nms_ref to check the case where the dtype is torch.bfloat16 on CPU. Could you please check if it looks good to you?
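The updated test_nms_ref itself is not visible in this 5-commit view. As a rough sketch only (helper names such as _create_tensors_with_iou and _reference_nms are taken from the surrounding test file, and the final parametrization may differ), folding the bfloat16 case into the reference test could look like this:

# Sketch of a dtype-parametrized test_nms_ref inside TestNMS; not the merged code.
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
def test_nms_ref(self, iou, dtype):
    boxes, scores = self._create_tensors_with_iou(1000, iou)
    if dtype == torch.bfloat16:
        # Compare the reduced-precision path against the same op run in float32;
        # exact index agreement with the pure-Python reference is not guaranteed in bfloat16.
        keep_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
        keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
        torch.testing.assert_close(keep_float, keep_dtype)
    else:
        keep_ref = self._reference_nms(boxes, scores, iou)
        keep = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_ref, msg=f"NMS incompatible with reference for IoU={iou}")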


    @pytest.mark.parametrize(
        "device",
        (
torchvision/csrc/ops/autocast/nms_kernel.cpp (14 changes: 10 additions & 4 deletions)
@@ -9,21 +9,27 @@ namespace ops {

namespace {

+template<c10::DispatchKey autocast_key, c10::DeviceType device_type>
at::Tensor nms_autocast(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
+
  return nms(
-      at::autocast::cached_cast(at::kFloat, dets),
-      at::autocast::cached_cast(at::kFloat, scores),
+      at::autocast::cached_cast(at::kFloat, dets, device_type),
+      at::autocast::cached_cast(at::kFloat, scores, device_type),
      iou_threshold);
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
+  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN((nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
}

+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN((nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
+}
+
} // namespace ops
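For context, a small usage sketch of what the new AutocastCPU registration enables (illustrative only; the boxes and scores below are made up):

import torch
from torchvision import ops

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0],
                      [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])

with torch.cpu.amp.autocast():
    # Even if the surrounding model produced bfloat16 tensors, the AutocastCPU
    # wrapper above casts them back to float32 before dispatching torchvision::nms.
    keep = ops.nms(boxes.to(torch.bfloat16), scores.to(torch.bfloat16), iou_threshold=0.5)

print(keep)  # tensor([0, 2]): the heavily overlapping box 1 is suppressed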
torchvision/csrc/ops/autocast/roi_align_kernel.cpp (15 changes: 11 additions & 4 deletions)
@@ -9,6 +9,7 @@ namespace ops {

namespace {

+template<c10::DispatchKey autocast_key, c10::DeviceType device_type>
at::Tensor roi_align_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
@@ -17,10 +18,10 @@ at::Tensor roi_align_autocast(
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
  return roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
      spatial_scale,
      pooled_height,
      pooled_width,
@@ -34,7 +35,13 @@ at::Tensor roi_align_autocast(
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN(roi_align_autocast));
+      TORCH_FN((roi_align_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
}

+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN((roi_align_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
+}
+
} // namespace ops
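Analogously, a minimal roi_align sketch under CPU autocast (shapes and values are arbitrary): the wrapper re-casts input and rois to float32, so the kernel runs in full precision even when autocast would otherwise pick bfloat16.

import torch
from torchvision.ops import roi_align

x = torch.rand(1, 8, 32, 32)                        # (N, C, H, W) feature map
rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]])  # (batch_index, x1, y1, x2, y2)

with torch.cpu.amp.autocast():
    out = roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0,
                    sampling_ratio=2, aligned=True)

print(out.shape)  # torch.Size([1, 8, 5, 5])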