-
Notifications
You must be signed in to change notification settings - Fork 184
BitwiseHamming distance for NN Descent #1101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
a81b370
4250f32
b4cc11e
b8e44ed
b19a9dc
fa662ab
c2aa027
e2eb7c3
4c56110
a372adf
a58e2ae
982b139
9ff3cd8
e593ba7
2d1d17c
6c6d1bd
883ac1a
f294140
67add6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
|
|
||
| #include "ann_utils.cuh" | ||
| #include "cagra/device_common.hpp" | ||
| #include "cuvs/distance/distance.h" | ||
| #include "nn_descent_gnnd.hpp" | ||
|
|
||
| #include <cuvs/distance/distance.hpp> | ||
|
|
@@ -285,6 +286,10 @@ RAFT_KERNEL preprocess_data_kernel( | |
| } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
| output_data[list_id * dim + idx] = | ||
| (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); | ||
| } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) { | ||
| int idx_for_byte = list_id * dim + idx; // uint8 or int8 data | ||
| uint8_t* output_bytes = reinterpret_cast<uint8_t*>(output_data); | ||
| output_bytes[idx_for_byte] = input_data[idx_for_byte]; | ||
| } else { // L2Expanded or L2SqrtExpanded | ||
| output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; | ||
| if (idx == 0) { l2_norms[list_id] = l2_norm; } | ||
|
|
@@ -588,40 +593,46 @@ __launch_bounds__(BLOCK_SIZE) | |
| wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag; | ||
| wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag; | ||
| wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag; | ||
| wmma::fill_fragment(c_frag, 0.0); | ||
| for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { | ||
| int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) | ||
| ? data_dim - step * TILE_COL_WIDTH | ||
| : TILE_COL_WIDTH; | ||
| if (metric != cuvs::distance::DistanceType::BitwiseHamming) { | ||
| wmma::fill_fragment(c_frag, 0.0); | ||
|
|
||
| for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { | ||
| int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) | ||
| ? data_dim - step * TILE_COL_WIDTH | ||
| : TILE_COL_WIDTH; | ||
| #pragma unroll | ||
| for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { | ||
| int idx = i * num_warps + warp_id; | ||
| if (idx < list_new_size) { | ||
| size_t neighbor_id = new_neighbors[idx]; | ||
| size_t idx_in_data = neighbor_id * data_dim; | ||
| load_vec(s_nv[idx], | ||
| data + idx_in_data + step * TILE_COL_WIDTH, | ||
| num_load_elems, | ||
| TILE_COL_WIDTH, | ||
| lane_id); | ||
| for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { | ||
| int idx = i * num_warps + warp_id; | ||
| if (idx < list_new_size) { | ||
| size_t neighbor_id = new_neighbors[idx]; | ||
| size_t idx_in_data = neighbor_id * data_dim; | ||
| load_vec(s_nv[idx], | ||
| data + idx_in_data + step * TILE_COL_WIDTH, | ||
| num_load_elems, | ||
| TILE_COL_WIDTH, | ||
| lane_id); | ||
| } | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { | ||
| wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); | ||
| wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); | ||
| wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); | ||
| __syncthreads(); | ||
|
|
||
| for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { | ||
| wmma::load_matrix_sync( | ||
| a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); | ||
| wmma::load_matrix_sync( | ||
| b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); | ||
| wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); | ||
| __syncthreads(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| wmma::store_matrix_sync( | ||
| s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, | ||
| c_frag, | ||
| SKEWED_MAX_NUM_BI_SAMPLES, | ||
| wmma::mem_row_major); | ||
| __syncthreads(); | ||
| wmma::store_matrix_sync( | ||
| s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, | ||
| c_frag, | ||
| SKEWED_MAX_NUM_BI_SAMPLES, | ||
| wmma::mem_row_major); | ||
|
|
||
| __syncthreads(); | ||
| } | ||
|
|
||
| for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { | ||
| int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES; | ||
|
|
@@ -632,6 +643,15 @@ __launch_bounds__(BLOCK_SIZE) | |
| s_distances[i] = -s_distances[i]; | ||
| } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
| s_distances[i] = 1.0 - s_distances[i]; | ||
| } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) { | ||
| s_distances[i] = 0.0; | ||
| int n1 = new_neighbors[row_id]; | ||
| int n2 = new_neighbors[col_id]; | ||
| const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim; | ||
| const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim; | ||
| for (int d = 0; d < data_dim; d++) { | ||
| s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff); | ||
| } | ||
| } else { // L2Expanded or L2SqrtExpanded | ||
| s_distances[i] = | ||
| l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i]; | ||
|
|
@@ -659,56 +679,60 @@ __launch_bounds__(BLOCK_SIZE) | |
|
|
||
| __syncthreads(); | ||
|
|
||
| wmma::fill_fragment(c_frag, 0.0); | ||
| for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { | ||
| int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) | ||
| ? data_dim - step * TILE_COL_WIDTH | ||
| : TILE_COL_WIDTH; | ||
| if (TILE_COL_WIDTH < data_dim) { | ||
| if (metric != cuvs::distance::DistanceType::BitwiseHamming) { | ||
| wmma::fill_fragment(c_frag, 0.0); | ||
| for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { | ||
| int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) | ||
| ? data_dim - step * TILE_COL_WIDTH | ||
| : TILE_COL_WIDTH; | ||
| if (TILE_COL_WIDTH < data_dim) { | ||
| #pragma unroll | ||
| for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { | ||
| int idx = i * num_warps + warp_id; | ||
| if (idx < list_new_size) { | ||
| size_t neighbor_id = new_neighbors[idx]; | ||
| size_t idx_in_data = neighbor_id * data_dim; | ||
| load_vec(s_nv[idx], | ||
| data + idx_in_data + step * TILE_COL_WIDTH, | ||
| num_load_elems, | ||
| TILE_COL_WIDTH, | ||
| lane_id); | ||
| } | ||
| } | ||
| } | ||
| #pragma unroll | ||
| for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { | ||
| int idx = i * num_warps + warp_id; | ||
| if (idx < list_new_size) { | ||
| size_t neighbor_id = new_neighbors[idx]; | ||
| if (idx < list_old_size) { | ||
| size_t neighbor_id = old_neighbors[idx]; | ||
| size_t idx_in_data = neighbor_id * data_dim; | ||
| load_vec(s_nv[idx], | ||
| load_vec(s_ov[idx], | ||
| data + idx_in_data + step * TILE_COL_WIDTH, | ||
| num_load_elems, | ||
| TILE_COL_WIDTH, | ||
| lane_id); | ||
| } | ||
| } | ||
| } | ||
| #pragma unroll | ||
| for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { | ||
| int idx = i * num_warps + warp_id; | ||
| if (idx < list_old_size) { | ||
| size_t neighbor_id = old_neighbors[idx]; | ||
| size_t idx_in_data = neighbor_id * data_dim; | ||
| load_vec(s_ov[idx], | ||
| data + idx_in_data + step * TILE_COL_WIDTH, | ||
| num_load_elems, | ||
| TILE_COL_WIDTH, | ||
| lane_id); | ||
| __syncthreads(); | ||
|
|
||
| for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { | ||
| wmma::load_matrix_sync( | ||
| a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); | ||
| wmma::load_matrix_sync( | ||
| b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); | ||
| wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); | ||
| __syncthreads(); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { | ||
| wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); | ||
| wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); | ||
| wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); | ||
| __syncthreads(); | ||
| } | ||
| wmma::store_matrix_sync( | ||
| s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, | ||
| c_frag, | ||
| SKEWED_MAX_NUM_BI_SAMPLES, | ||
| wmma::mem_row_major); | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| wmma::store_matrix_sync( | ||
| s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, | ||
| c_frag, | ||
| SKEWED_MAX_NUM_BI_SAMPLES, | ||
| wmma::mem_row_major); | ||
| __syncthreads(); | ||
|
|
||
| for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { | ||
| int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES; | ||
| int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES; | ||
|
|
@@ -717,6 +741,14 @@ __launch_bounds__(BLOCK_SIZE) | |
| s_distances[i] = -s_distances[i]; | ||
| } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
| s_distances[i] = 1.0 - s_distances[i]; | ||
| } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) { | ||
| int n1 = old_neighbors[row_id]; | ||
| int n2 = new_neighbors[col_id]; | ||
| const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim; | ||
| const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim; | ||
| for (int d = 0; d < data_dim; d++) { | ||
| s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Wherever you are doing this, have you ensured that the dim that you pass along is correctly scaled? If you were to convert a half to two uint8s, the dim would have to be doubled.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I am allocating dim/2 dimensions for the half-type array. To make things straightforward, dim is always used as the dimension of the "given dataset" (fp32, int8, etc.), and the dimensions for allocating the fp16-type device array are configured based on the data type (dim/2 for int8 and uint8, as-is for other types).
The issue with keeping the pointer as fp16 and iterating over data_dim/2 dimensions is that we would have to check inside the kernel whether the last byte holds a valid value (because the original int8 data could have an odd number of dimensions). I thought it would be easier to cast to int8 and loop over the original data_dim instead : )
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see. Yes, there can be an odd number of dims, for which we can fall back to int8, but if it's not an odd number of dims, we can do it in the half space. I'd argue that we can do even more -- if it is divisible by 4, reinterpret_cast to uint32_t so you'd have to popcount only over one fourth of the dim (I'm doing the same thing with BitwiseHamming in ivf-flat). However, considering the deadlines, we can look into these optimizations later. Can we create a GitHub issue for this and write it as a TODO here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, added an issue here #1127 |
||
| } | ||
| } else { // L2Expanded or L2SqrtExpanded | ||
| s_distances[i] = | ||
| l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i]; | ||
|
|
@@ -980,7 +1012,11 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build | |
| nrow_(build_config.max_dataset_size), | ||
| ndim_(build_config.dataset_dim), | ||
| d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>( | ||
| res, nrow_, build_config.dataset_dim)}, | ||
| res, | ||
| nrow_, | ||
| build_config.metric == cuvs::distance::DistanceType::BitwiseHamming | ||
| ? (build_config.dataset_dim + 1) / 2 | ||
| : build_config.dataset_dim)}, | ||
|
Comment on lines
+1017
to
+1021
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep that's what we do here! |
||
| l2_norms_{raft::make_device_vector<DistData_t, size_t>(res, 0)}, | ||
| graph_buffer_{ | ||
| raft::make_device_matrix<ID_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)}, | ||
|
|
@@ -1071,12 +1107,19 @@ void GNND<Data_t, Index_t>::build(Data_t* data, | |
| { | ||
| using input_t = typename std::remove_const<Data_t>::type; | ||
|
|
||
| if (build_config_.metric == cuvsDistanceType::BitwiseHamming && | ||
| !(std::is_same_v<input_t, uint8_t> || std::is_same_v<input_t, int8_t>)) { | ||
| RAFT_FAIL( | ||
| "Data type needs to be int8 or uint8 for NN Descent to run with BitwiseHamming distance."); | ||
| } | ||
|
|
||
| cudaStream_t stream = raft::resource::get_cuda_stream(res); | ||
| nrow_ = nrow; | ||
| graph_.nrow = nrow; | ||
| graph_.bloom_filter.set_nrow(nrow); | ||
| update_counter_ = 0; | ||
| graph_.h_graph = (InternalID_t<Index_t>*)output_graph; | ||
| raft::matrix::fill(res, d_data_.view(), static_cast<__half>(0)); | ||
|
|
||
| cudaPointerAttributes data_ptr_attr; | ||
| RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only diff here is wrapping the wmma operation part with
`if (metric != cuvs::distance::DistanceType::BitwiseHamming)` so that we don't perform unnecessary computations for BitwiseHamming. (The diff seems to detect all additional indents as changes, which makes it confusing.)