Skip to content

Commit 89a2ec7

Browse files
committed
Comments + add simple bfloat16 w/ AVQ test
1 parent 8824bcd commit 89a2ec7

5 files changed

Lines changed: 34 additions & 16 deletions

File tree

cpp/include/cuvs/neighbors/scann.hpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,20 @@ struct index_params : cuvs::neighbors::index_params {
7575
uint32_t pq_train_iters = 10;
7676

7777
/** whether to apply bf16 quantization of dataset vectors **/
78-
bool reordering_bf16_enabled = false;
79-
80-
/** Threshold for computing AVQ eta via Theorem 3.4 in https://arxiv.org/abs/1908.10396
78+
bool reordering_bf16 = false;
79+
80+
/** Threshold T for computing AVQ eta = (dim - 1) ( T^2 / || x ||^2) / ( 1 - T^2 / || x ||^2)
81+
*
82+
* When quantizing a vector x to x_q, AVQ minimizes the loss function
83+
* L(x, x_q) = eta * || r_para ||^2 + || r_perp ||^2, where
84+
* r = x - x_q, r_para = <r, x> * x / || x ||^2, r_perp = r - r_para
85+
*
86+
* Compared to L2 loss, this produces an x_q which better approximates
87+
* the dot product of a query vector with x
88+
*
8189
* If the threshold is NAN, AVQ is not performed during bfloat16 quant
8290
*/
8391
float reordering_noise_shaping_threshold = NAN;
84-
// TODO - add other scann build params
8592
};
8693

8794
/**
@@ -141,7 +148,7 @@ struct index : cuvs::neighbors::index {
141148
IdxT dim,
142149
uint32_t pq_clusters,
143150
uint32_t pq_num_subspaces,
144-
bool reordering_bf16_enabled)
151+
bool reordering_bf16)
145152
: cuvs::neighbors::index(),
146153
metric_(metric),
147154
pq_dim_(pq_dim),
@@ -159,7 +166,7 @@ struct index : cuvs::neighbors::index {
159166
n_rows_(n_rows),
160167
dim_(dim),
161168
bf16_dataset_(raft::make_host_matrix<int16_t, IdxT, raft::row_major>(
162-
reordering_bf16_enabled ? n_rows : 0, reordering_bf16_enabled ? dim : 0))
169+
reordering_bf16 ? n_rows : 0, reordering_bf16 ? dim : 0))
163170

164171
{
165172
}
@@ -174,7 +181,7 @@ struct index : cuvs::neighbors::index {
174181
dim,
175182
1 << params.pq_bits,
176183
dim / params.pq_dim,
177-
params.reordering_bf16_enabled)
184+
params.reordering_bf16)
178185
{
179186
RAFT_EXPECTS(params.pq_bits == 4 || params.pq_bits == 8, "ScaNN only supports 4 or 8 bit PQ");
180187
RAFT_EXPECTS(dim >= params.pq_dim,

cpp/src/neighbors/scann/detail/scann_build.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ index<T, IdxT> build(
313313
// TODO (rmaschal): Might be more efficient to do on CPU, to avoid DtoH copy
314314
auto bf16_dataset = raft::make_device_matrix<int16_t, int64_t>(res, batch_view.extent(0), dim);
315315

316-
if (params.reordering_bf16_enabled) {
316+
if (params.reordering_bf16) {
317317
quantize_bfloat16(
318318
res, batch_view, bf16_dataset.view(), params.reordering_noise_shaping_threshold);
319319
}
@@ -333,7 +333,7 @@ index<T, IdxT> build(
333333
quantized_soar_residuals.size(),
334334
stream);
335335

336-
if (params.reordering_bf16_enabled) {
336+
if (params.reordering_bf16) {
337337
raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim,
338338
bf16_dataset.data_handle(),
339339
bf16_dataset.size(),

cpp/src/neighbors/scann/detail/scann_quantize.cuh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@ __device__ inline float bfloat16_to_float(int16_t& bf16)
314314
* Based on the signs of the current residual and quantized value,
315315
* increment or decrement the quantized value to push residual closer to 0
316316
*
317+
* Note that the bfloat16 value is encoded as an int16_t, and the
318+
* increment/decrement is applied to the encoded value. In terms of the float
319+
* representation, it is the mantissa that is being incremented/decremented,
320+
* which could carry over to the exponent
321+
*
317322
* @param res the float residual
318323
* @param current the current quantized dimension
319324
* @return the other possible quantized value
@@ -328,9 +333,6 @@ __device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current)
328333
return current + 1;
329334
}
330335

331-
/**
332-
*
333-
*/
334336
template <uint32_t BlockSize, typename IdxT>
335337
__launch_bounds__(BlockSize) RAFT_KERNEL
336338
quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,

cpp/tests/neighbors/ann_scann.cuh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class scann_test : public ::testing::TestWithParam<scann_inputs> {
131131
ASSERT_EQ(index.pq_codebook().extent(0), num_pq_clusters);
132132
ASSERT_EQ(index.pq_codebook().extent(1), ps.dim);
133133

134-
IdxT expected_bf16_size = ps.index_params.reordering_bf16_enabled ? ps.dim * ps.num_db_vecs : 0;
134+
IdxT expected_bf16_size = ps.index_params.reordering_bf16 ? ps.dim * ps.num_db_vecs : 0;
135135

136136
ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size);
137137
}
@@ -227,7 +227,16 @@ inline auto big_dims_all_pq_bits() -> test_cases_t
227227
inline auto bf16() -> test_cases_t
228228
{
229229
scann_inputs ts;
230-
ts.index_params.reordering_bf16_enabled = true;
230+
ts.index_params.reordering_bf16 = true;
231+
232+
return {ts};
233+
}
234+
235+
inline auto bf16_avq() -> test_cases_t
236+
{
237+
scann_inputs ts;
238+
ts.index_params.reordering_bf16 = true;
239+
ts.index_params.reordering_noise_shaping_threshold = 0.2;
231240

232241
return {ts};
233242
}

cpp/tests/neighbors/ann_scann/test_float_int64_t.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ TEST_BUILD_HOST_INPUT(f32_i64)
2525
TEST_BUILD_HOST_INPUT_OVERLAP(f32_i64);
2626

2727
INSTANTIATE(f32_i64,
28-
defaults() + small_dims_all_pq_bits() + big_dims_all_pq_bits() + bf16() + avq() +
29-
soar());
28+
defaults() + small_dims_all_pq_bits() + big_dims_all_pq_bits() + bf16() + bf16_avq() +
29+
avq() + soar());
3030

3131
} // namespace cuvs::neighbors::experimental::scann

0 commit comments

Comments
 (0)