diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 29f790ec51..652d41c853 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -63,14 +63,9 @@ struct pointer_residency_count { auto [on_device, on_host] = pointer_residency_count::run(ptrs...); cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - switch (attr.type) { - case cudaMemoryTypeUnregistered: return std::make_tuple(on_device, on_host + 1); - case cudaMemoryTypeHost: - return std::make_tuple(on_device + int(attr.devicePointer == ptr), on_host + 1); - case cudaMemoryTypeDevice: return std::make_tuple(on_device + 1, on_host); - case cudaMemoryTypeManaged: return std::make_tuple(on_device + 1, on_host + 1); - default: return std::make_tuple(on_device, on_host); - } + if (attr.devicePointer || attr.type == cudaMemoryTypeDevice) { ++on_device; } + if (attr.hostPointer || attr.type == cudaMemoryTypeUnregistered) { ++on_host; } + return std::make_tuple(on_device, on_host); } };