8 changes: 4 additions & 4 deletions setup.py
@@ -129,8 +129,10 @@ def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "ops", "*.cpp")
+    main_file = (
+        glob.glob(os.path.join(extensions_dir, "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
     )
     source_cpu = (
         glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
@@ -184,8 +186,6 @@ def get_extensions():
     else:
         source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

-    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
-
     sources = main_file + source_cpu
     extension = CppExtension

30 changes: 29 additions & 1 deletion test/test_ops.py
@@ -122,6 +122,8 @@ def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, determinist
                 tol = 5e-3
             else:
                 tol = 4e-3
+        elif x_dtype == torch.bfloat16:
+            tol = 5e-3

         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -493,6 +495,21 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
                 rois_dtype=rois_dtype,
             )

+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
+        with torch.cpu.amp.autocast():
+            self.test_forward(
+                torch.device("cpu"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
+            )
+
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
@@ -712,14 +729,19 @@ def _create_tensors_with_iou(self, N, iou_thresh):

     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("seed", range(10))
-    def test_nms_ref(self, iou, seed):
+    def test_nms_ref(self, iou, seed, dtype=torch.float):
         torch.random.manual_seed(seed)
         err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
         boxes, scores = self._create_tensors_with_iou(1000, iou)
         keep_ref = self._reference_nms(boxes, scores, iou)
         keep = ops.nms(boxes, scores, iou)
         torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

+        if dtype == torch.bfloat16:
+            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
+            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
+            torch.testing.assert_close(keep_ref_float, keep_dtype)
Member:

Thanks for the update and sorry for the noise - shouldn't we instead assert that we get equivalent results when autocast is on and when it's off? e.g. something like

    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    def test_autocast_cpu(self, iou, seed):
        torch.random.manual_seed(seed)
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep = ops.nms(boxes, scores, iou)
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            keep_autocast = ops.nms(boxes, scores, iou)
        torch.testing.assert_close(keep, keep_autocast)

The test for roi_align calls test_forward(), which has slightly different logic: there we check our implementation against a reference, both with autocast. Unfortunately that same test seems to fail for nms, so maybe we can just use the snippet above. WDYT?

Contributor Author:

Thanks for the suggestion. Here are two concerns:

  1. Since nms is not in the autocast white list, simply running the above UT won't expose the issue this PR is trying to fix, because the input to nms will still be fp32 in that snippet. Instead, we need to feed nms a BF16 input under autocast: without this PR it throws RuntimeError: "nms_kernel" not implemented for 'BFloat16', while with this PR autocast converts the input back to FP32 so that it runs successfully.
  2. A modified version following item 1 would look like this:
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
    @pytest.mark.parametrize("seed", range(10))
    def test_autocast_cpu(self, iou, seed):
        torch.random.manual_seed(seed)
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        keep = ops.nms(boxes, scores, iou)
        with torch.cpu.amp.autocast(dtype=torch.bfloat16):
            keep_autocast = ops.nms(boxes.to(torch.bfloat16), scores.to(torch.bfloat16), iou)
        torch.testing.assert_close(keep, keep_autocast)

The issue with this version is that boxes.to(torch.bfloat16) and scores.to(torch.bfloat16) have converted the data from high precision to low precision, so directly comparing keep_autocast against keep fails with:
AssertionError: The values for attribute 'shape' do not match: torch.Size([434]) != torch.Size([436]).
In the current implementation I used boxes.to(dtype).float() to simulate that high-to-low precision round trip and cast back to float, which gives a reference result that can be compared against the autocast result.
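
A condensed, self-contained sketch of that comparison (the random boxes here are a hypothetical stand-in for _create_tensors_with_iou):

    import torch
    from torchvision import ops

    boxes = torch.rand(100, 4) * 100
    boxes[:, 2:] += boxes[:, :2]  # make (x2, y2) >= (x1, y1) so boxes are valid
    scores = torch.rand(100)

    # Reference: push the inputs through the same bf16 precision loss,
    # then run nms in fp32.
    keep_ref = ops.nms(boxes.bfloat16().float(), scores.bfloat16().float(), 0.5)

    # Under CPU autocast, the AutocastCPU wrapper casts the bf16 inputs
    # back to fp32 before dispatching to the fp32 kernel.
    with torch.cpu.amp.autocast():
        keep_autocast = ops.nms(boxes.bfloat16(), scores.bfloat16(), 0.5)

    # Both sides saw the same precision loss, so the kept indices match.
    torch.testing.assert_close(keep_ref, keep_autocast)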


     def test_nms_input_errors(self):
         with pytest.raises(RuntimeError):
             ops.nms(torch.rand(4), torch.rand(3), 0.5)
@@ -782,6 +804,12 @@ def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
             self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, iou, dtype):
+        with torch.cpu.amp.autocast():
+            self.test_nms_ref(iou=iou, seed=0, dtype=dtype)
+
     @pytest.mark.parametrize(
         "device",
         (
20 changes: 16 additions & 4 deletions torchvision/csrc/ops/autocast/nms_kernel.cpp
@@ -9,21 +9,33 @@ namespace ops {

 namespace {

+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor nms_autocast(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);

   return nms(
-      at::autocast::cached_cast(at::kFloat, dets),
-      at::autocast::cached_cast(at::kFloat, scores),
+      at::autocast::cached_cast(at::kFloat, dets, device_type),
+      at::autocast::cached_cast(at::kFloat, scores, device_type),
       iou_threshold);
 }

 } // namespace

 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
 }

 } // namespace ops
19 changes: 15 additions & 4 deletions torchvision/csrc/ops/autocast/roi_align_kernel.cpp
@@ -9,6 +9,7 @@ namespace ops {

 namespace {

+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -17,10 +18,10 @@ at::Tensor roi_align_autocast(
     int64_t pooled_width,
     int64_t sampling_ratio,
     bool aligned) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   return roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width,
@@ -34,7 +35,17 @@
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN(roi_align_autocast));
+      TORCH_FN((roi_align_autocast<
+          c10::DispatchKey::Autocast,
+          c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN((roi_align_autocast<
+          c10::DispatchKey::AutocastCPU,
+          c10::DeviceType::CPU>)));
 }

 } // namespace ops
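
With both registrations in place, roi_align accepts bfloat16 inputs under CPU autocast as well; a minimal usage sketch (arbitrary shapes, assuming a torchvision build that includes this change):

    import torch
    from torchvision.ops import roi_align

    x = torch.rand(1, 25, 32, 32)                       # NCHW feature map
    rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]])  # (batch_idx, x1, y1, x2, y2)

    with torch.cpu.amp.autocast():
        # The AutocastCPU wrapper casts both bf16 inputs back to fp32 before
        # dispatching, so the op runs in full precision regardless of the
        # autocast dtype.
        out = roi_align(x.bfloat16(), rois.bfloat16(), output_size=(5, 5))

    print(out.dtype)  # torch.float32: the kernel itself ran in fp32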