173 changes: 108 additions & 65 deletions cpp/src/neighbors/detail/nn_descent.cuh
@@ -18,6 +18,7 @@

#include "ann_utils.cuh"
#include "cagra/device_common.hpp"
#include "cuvs/distance/distance.h"
#include "nn_descent_gnnd.hpp"

#include <cuvs/distance/distance.hpp>
@@ -285,6 +286,10 @@ RAFT_KERNEL preprocess_data_kernel(
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
output_data[list_id * dim + idx] =
(float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm);
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
int idx_for_byte = list_id * dim + idx; // uint8 or int8 data
uint8_t* output_bytes = reinterpret_cast<uint8_t*>(output_data);
output_bytes[idx_for_byte] = input_data[idx_for_byte];
} else { // L2Expanded or L2SqrtExpanded
output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx];
if (idx == 0) { l2_norms[list_id] = l2_norm; }
@@ -588,40 +593,46 @@ __launch_bounds__(BLOCK_SIZE)
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
@jinsolp (Contributor Author) commented on Jul 10, 2025:

> The only diff here is wrapping the wmma operation in `if (metric != cuvs::distance::DistanceType::BitwiseHamming)` so that we don't perform unnecessary computations for BitwiseHamming. (The diff marks all of the added indentation as changes, which makes it look larger than it is.)

wmma::fill_fragment(c_frag, 0.0);

for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(
a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(
b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
}
}

wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
__syncthreads();
wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);

__syncthreads();
}

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
@@ -632,6 +643,15 @@ __launch_bounds__(BLOCK_SIZE)
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
s_distances[i] = 0.0;
int n1 = new_neighbors[row_id];
int n2 = new_neighbors[col_id];
const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
for (int d = 0; d < data_dim; d++) {
s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
}
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] =
l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
@@ -659,56 +679,60 @@ __launch_bounds__(BLOCK_SIZE)

__syncthreads();

wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (TILE_COL_WIDTH < data_dim) {
if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (TILE_COL_WIDTH < data_dim) {
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
if (idx < list_old_size) {
size_t neighbor_id = old_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
load_vec(s_ov[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_old_size) {
size_t neighbor_id = old_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_ov[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(
a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(
b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
}
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
__syncthreads();
}

wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
@@ -717,6 +741,14 @@ __launch_bounds__(BLOCK_SIZE)
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
s_distances[i] = 0.0;  // the wmma path above was skipped, so clear any stale accumulator value
int n1 = old_neighbors[row_id];
int n2 = new_neighbors[col_id];
const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
for (int d = 0; d < data_dim; d++) {
s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
@tarang-jain (Contributor) commented:

> Wherever you are doing this, have you ensured that the dim you pass along is correctly scaled? If you were to convert one half into two uint8s, the dim would have to be doubled.
>
> Furthermore, I haven't looked into the NN Descent logic entirely, but if you are operating in the half space, I don't think you have to reinterpret_cast everywhere to uint8. It is more efficient to stay in the half space and do something like:
>
> `__popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xffffu)`
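A minimal sketch of the two variants under discussion, assuming the packed bytes are viewed as `uint16_t` words for the half-space version (bitwise XOR is not defined on `__half` itself) and that the byte count is even; the helper names are illustrative, not from this PR:

```cpp
#include <cstdint>

// Byte-wise Hamming distance, as in this PR: one popcount per uint8 lane.
__device__ inline int hamming_bytes(const uint8_t* a, const uint8_t* b, int dim_bytes)
{
  int acc = 0;
  for (int d = 0; d < dim_bytes; d++) {
    acc += __popc(static_cast<uint32_t>(a[d] ^ b[d]) & 0xffu);
  }
  return acc;
}

// Half-space variant suggested above: view the __half storage as uint16_t and
// cover two bytes per popcount. Only valid when the byte count is even.
__device__ inline int hamming_halfwords(const uint16_t* a, const uint16_t* b, int dim_halves)
{
  int acc = 0;
  for (int d = 0; d < dim_halves; d++) {
    acc += __popc(static_cast<uint32_t>(a[d] ^ b[d]) & 0xffffu);
  }
  return acc;
}
```

Both loops count the same set of differing bits; the second simply halves the trip count.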

@jinsolp (Contributor Author) commented on Jul 16, 2025:

> > wherever you are doing this, have you ensured that the dim that you pass along is correctly scaled? If you were to convert a half to two uint8s, the dim would have to be doubled.
>
> I am allocating dim/2 dimensions for the half-type array. To keep things straightforward, dim always refers to the dimension of the given dataset (fp32, int8, etc.), and the dimension used to allocate the fp16 device array is derived from the data type (dim/2 for int8 and uint8, as-is for other types).
>
> > I dont think you'd have to reinterpret_cast everywhere to uint8.
>
> The issue with keeping the pointer fp16 and looping over data_dim/2 dimensions is that the kernel would then have to check whether the last byte holds a valid value (the original int8 data can have an odd number of dimensions). I thought it would be easier to cast to int8 and loop over the original data_dim instead :)
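A small sketch of that bookkeeping (the function name is illustrative, not from this PR): only the `__half` staging buffer is sized differently for byte data, matching the `(build_config.dataset_dim + 1) / 2` allocation further down in this diff.

```cpp
#include <cstddef>

// Two (u)int8 values are packed into each __half element, so the staging
// buffer needs ceil(dataset_dim / 2) columns. An odd dataset_dim leaves the
// upper byte of the last element unused, which is why the distance kernels
// iterate byte-wise over the original dataset_dim.
inline size_t staging_dim(size_t dataset_dim, bool packed_byte_data)
{
  return packed_byte_data ? (dataset_dim + 1) / 2 : dataset_dim;
}
```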

@tarang-jain (Contributor) commented on Jul 16, 2025:

> I see. Yes, there can be an odd number of dims, for which we can fall back to int8, but when it is even we can do it in the half space. I'd argue we can do even more: if the dim is divisible by 4, reinterpret_cast to uint32_t so you only have to popcount over one fourth of the dim (I'm doing the same thing with BitwiseHamming in ivf-flat). However, considering the deadlines, we can look into these optimizations later. Can we create a GitHub issue for this and write it as a TODO here?
>
> Regarding the dims, I just wanted to verify that the dims being used everywhere are scaled correctly, but it looks like you have already checked those things, so apart from the creation of that GitHub issue this PR looks okay to me.
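A hedged sketch of that deferred optimization (tracked in the issue referenced below, not implemented in this PR), assuming the byte rows are 4-byte aligned:

```cpp
#include <cstdint>

// Word-wise Hamming distance: one popcount per four bytes, plus a byte-wise
// tail when dim_bytes is not divisible by 4. Assumes a and b are 4-byte
// aligned, which holds at the start of suitably aligned rows.
__device__ inline int hamming_words(const uint8_t* a, const uint8_t* b, int dim_bytes)
{
  int acc = 0;
  const int full_words = dim_bytes / 4;
  const uint32_t* a32  = reinterpret_cast<const uint32_t*>(a);
  const uint32_t* b32  = reinterpret_cast<const uint32_t*>(b);
  for (int w = 0; w < full_words; w++) {
    acc += __popc(a32[w] ^ b32[w]);
  }
  for (int d = full_words * 4; d < dim_bytes; d++) {
    acc += __popc(static_cast<uint32_t>(a[d] ^ b[d]) & 0xffu);
  }
  return acc;
}
```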

@jinsolp (Contributor Author) commented:

> Makes sense, added an issue here: #1127

}
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] =
l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
@@ -980,7 +1012,11 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
nrow_(build_config.max_dataset_size),
ndim_(build_config.dataset_dim),
d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>(
res, nrow_, build_config.dataset_dim)},
res,
nrow_,
build_config.metric == cuvs::distance::DistanceType::BitwiseHamming
? (build_config.dataset_dim + 1) / 2
: build_config.dataset_dim)},
Comment on lines +1017 to +1021

@jinsolp (Contributor Author) commented:

> Yep, that's what we do here!

l2_norms_{raft::make_device_vector<DistData_t, size_t>(res, 0)},
graph_buffer_{
raft::make_device_matrix<ID_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
@@ -1071,12 +1107,19 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
{
using input_t = typename std::remove_const<Data_t>::type;

if (build_config_.metric == cuvsDistanceType::BitwiseHamming &&
!(std::is_same_v<input_t, uint8_t> || std::is_same_v<input_t, int8_t>)) {
RAFT_FAIL(
"Data type needs to be int8 or uint8 for NN Descent to run with BitwiseHamming distance.");
}

cudaStream_t stream = raft::resource::get_cuda_stream(res);
nrow_ = nrow;
graph_.nrow = nrow;
graph_.bloom_filter.set_nrow(nrow);
update_counter_ = 0;
graph_.h_graph = (InternalID_t<Index_t>*)output_graph;
raft::matrix::fill(res, d_data_.view(), static_cast<__half>(0));

cudaPointerAttributes data_ptr_attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
7 changes: 4 additions & 3 deletions cpp/src/neighbors/detail/nn_descent_gnnd.hpp
@@ -271,10 +271,11 @@ inline BuildConfig get_build_config(raft::resources const& res,
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
params.metric == cuvs::distance::DistanceType::InnerProduct ||
params.metric == cuvs::distance::DistanceType::BitwiseHamming;
RAFT_EXPECTS(allowed_metrics,
"The metric for NN Descent should be L2Expanded, L2SqrtExpanded, CosineExpanded or "
"InnerProduct");
"The metric for NN Descent should be L2Expanded, L2SqrtExpanded, CosineExpanded, "
"InnerProduct or BitwiseHamming");
RAFT_EXPECTS(
metric == params.metric,
"The metrics set in nn_descent::index_params and nn_descent::index are inconsistent");
8 changes: 7 additions & 1 deletion cpp/tests/neighbors/ann_nn_descent.cuh
@@ -25,6 +25,7 @@
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/nn_descent.hpp>

#include <raft/core/host_mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/itertools.hpp>
@@ -86,6 +87,10 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
protected:
void testNNDescent()
{
if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming &&
!(std::is_same_v<DataT, uint8_t> || std::is_same_v<DataT, int8_t>)) {
GTEST_SKIP();
}
size_t queries_size = ps.n_rows * ps.graph_degree;
std::vector<IdxT> indices_NNDescent(queries_size);
std::vector<DistanceT> distances_NNDescent(queries_size);
@@ -474,7 +479,8 @@ const std::vector<AnnNNDescentInputs> inputs =
raft::util::itertools::product<AnnNNDescentInputs>({2000, 4000}, // n_rows
{4, 16, 64, 256, 1024}, // dim
{32, 64}, // graph_degree
{cuvs::distance::DistanceType::L2Expanded,
{cuvs::distance::DistanceType::BitwiseHamming,
cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::L2SqrtExpanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
4 changes: 2 additions & 2 deletions cpp/tests/neighbors/naive_knn.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist,
acc += diff * diff;
} break;
case cuvs::distance::DistanceType::BitwiseHamming: {
if constexpr (std::is_same_v<uint8_t, DataT>) {
if constexpr (std::is_same_v<uint8_t, DataT> || std::is_same_v<int8_t, DataT>) {
acc += __popc(static_cast<uint32_t>(xv ^ yv) & 0xff);
}
} break;