Skip to content

Commit 13e3af4

Browse files
committed
Add AVQ to bfloat quantization
1 parent 88f7f23 commit 13e3af4

4 files changed

Lines changed: 242 additions & 17 deletions

File tree

cpp/include/cuvs/neighbors/scann.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <raft/util/integer_utils.hpp>
3131
#include <rmm/cuda_stream_view.hpp>
3232

33+
#include <cmath>
3334
#include <optional>
3435
#include <variant>
3536

@@ -74,8 +75,12 @@ struct index_params : cuvs::neighbors::index_params {
7475
uint32_t pq_train_iters = 10;
7576

7677
/** whether to apply bf16 quantization of dataset vectors **/
77-
bool bf16_enabled = false;
78+
bool reordering_bf16_enabled = false;
7879

80+
/** Threshold for computing AVQ eta via Theorem 3.4 in https://arxiv.org/abs/1908.10396
81+
* If the threshold is NAN, AVQ is not performed during bfloat16 quant
82+
*/
83+
float reordering_noise_shaping_threshold = NAN;
7984
// TODO - add other scann build params
8085
};
8186

@@ -136,7 +141,7 @@ struct index : cuvs::neighbors::index {
136141
IdxT dim,
137142
uint32_t pq_clusters,
138143
uint32_t pq_num_subspaces,
139-
bool bf16_enabled)
144+
bool reordering_bf16_enabled)
140145
: cuvs::neighbors::index(),
141146
metric_(metric),
142147
pq_dim_(pq_dim),
@@ -154,7 +159,7 @@ struct index : cuvs::neighbors::index {
154159
n_rows_(n_rows),
155160
dim_(dim),
156161
bf16_dataset_(raft::make_host_matrix<int16_t, IdxT, raft::row_major>(
157-
bf16_enabled ? n_rows : 0, bf16_enabled ? dim : 0))
162+
reordering_bf16_enabled ? n_rows : 0, reordering_bf16_enabled ? dim : 0))
158163

159164
{
160165
}
@@ -169,7 +174,7 @@ struct index : cuvs::neighbors::index {
169174
dim,
170175
1 << params.pq_bits,
171176
dim / params.pq_dim,
172-
params.bf16_enabled)
177+
params.reordering_bf16_enabled)
173178
{
174179
RAFT_EXPECTS(params.pq_bits == 4 || params.pq_bits == 8, "ScaNN only supports 4 or 8 bit PQ");
175180
RAFT_EXPECTS(dim >= params.pq_dim,

cpp/src/neighbors/scann/detail/scann_build.cuh

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,9 @@ index<T, IdxT> build(
313313
// TODO (rmaschal): Might be more efficient to do on CPU, to avoid DtoH copy
314314
auto bf16_dataset = raft::make_device_matrix<int16_t, int64_t>(res, batch_view.extent(0), dim);
315315

316-
if (params.bf16_enabled) {
317-
raft::linalg::unaryOp(
318-
bf16_dataset.data_handle(),
319-
batch_view.data_handle(),
320-
batch_view.size(),
321-
[] __device__(T x) {
322-
nv_bfloat16 val = __float2bfloat16(x);
323-
return reinterpret_cast<int16_t&>(val);
324-
},
325-
resource::get_cuda_stream(res));
316+
if (params.reordering_bf16_enabled) {
317+
quantize_bfloat16(
318+
res, batch_view, bf16_dataset.view(), params.reordering_noise_shaping_threshold);
326319
}
327320

328321
// Prefetch next batch
@@ -340,7 +333,7 @@ index<T, IdxT> build(
340333
quantized_soar_residuals.size(),
341334
stream);
342335

343-
if (params.bf16_enabled) {
336+
if (params.reordering_bf16_enabled) {
344337
raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim,
345338
bf16_dataset.data_handle(),
346339
bf16_dataset.size(),

cpp/src/neighbors/scann/detail/scann_quantize.cuh

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
#include "../../detail/vpq_dataset.cuh"
1818
#include <chrono>
19+
#include <cmath>
1920
#include <cuvs/neighbors/common.hpp>
21+
#include <raft/linalg/transpose.cuh>
2022
#include <raft/matrix/gather.cuh>
2123

2224
#include "scann_common.cuh"
@@ -267,6 +269,231 @@ void unpack_codes(raft::resources const& res,
267269
}
268270
}
269271

272+
/**
 * @brief Compute the AVQ weight eta per Theorem 3.4 in https://arxiv.org/abs/1908.10396
 *
 * @tparam IdxT index type of the dataset dimension
 * @param dim the dataset dimension
 * @param sq_norm the squared norm of the vector
 * @param threshold the threshold T in the Theorem
 * @return eta, the weight on the parallel residual component of the AVQ loss
 */
template <typename IdxT>
__device__ inline float compute_avq_eta(IdxT dim, const float sq_norm, const float threshold)
{
  // eta = (dim - 1) * ratio / (1 - ratio) with ratio = T^2 / ||x||^2.
  // NOTE(review): assumes sq_norm > threshold^2, otherwise the denominator is <= 0 —
  // confirm callers guarantee this.
  const float ratio = threshold * threshold / sq_norm;
  return (dim - 1) * ratio / (1 - ratio);
}
286+
287+
/**
 * @brief helper to convert a float to bfloat16 (represented as int16_t)
 *
 * @param f the float value
 * @return the bfloat16 value (as int16_t)
 */
__device__ inline int16_t float_to_bfloat16(const float& f)
{
  // Round-to-nearest conversion, then reinterpret the 16-bit pattern for int16_t storage.
  nv_bfloat16 rounded = __float2bfloat16(f);
  return *reinterpret_cast<int16_t*>(&rounded);
}
298+
299+
/**
 * @brief helper to convert a bfloat16 (represented as int16_t) to float
 *
 * @param bf16 the bf16 value (represented as int16_t)
 * @return the float value
 */
__device__ inline float bfloat16_to_float(const int16_t& bf16)
{
  // const-correct: the input is read-only, so accept const/rvalue sources too.
  // Reinterpret the stored bit pattern as nv_bfloat16, then widen losslessly to float.
  const nv_bfloat16 nv_bf16 = reinterpret_cast<const nv_bfloat16&>(bf16);
  return __bfloat162float(nv_bf16);
}
310+
311+
/**
 * @brief Select the next bfloat16 value to try during coordinate descent
 *
 * Based on the signs of the current residual and quantized value, increment or
 * decrement the quantized bit pattern to push the residual (original - quantized)
 * closer to 0. The caller's cost check decides whether the candidate is accepted.
 *
 * @param res the float residual (original - quantized)
 * @param current the current quantized dimension (bfloat16 bit pattern)
 * @return the neighboring bfloat16 bit pattern to try
 */
__device__ inline int16_t bfloat16_next_delta(const float& res, const int16_t& current)
{
  // Sign bits (1 == negative). Use a bit reinterpretation of the float: a value cast
  // ((int32_t)res) truncates the typically sub-ulp residual to 0 and loses the sign.
  uint32_t res_sign  = __float_as_uint(res) >> 31;
  uint32_t curr_sign = (static_cast<uint16_t>(current) >> 15);

  // bfloat16 bit patterns are monotone in magnitude: pattern + 1 moves a positive value
  // up and a negative value further down. Equal signs mean |quantized| is too small
  // (e.g. residual > 0 with quantized > 0), so step away from zero; opposite signs mean
  // overshoot, so step toward zero.
  if (res_sign == curr_sign) { return current + 1; }

  return current - 1;
}
330+
331+
/**
 * @brief Quantize rows of a float dataset to bfloat16 with AVQ noise shaping
 *
 * One warp processes one row: lanes stride across the row's dimensions, then run
 * warp-synchronous coordinate descent on the AVQ loss
 *   eta * ||r_parallel||^2 + ||r_perpendicular||^2
 * (Theorem 3.4 in https://arxiv.org/abs/1908.10396).
 *
 * Launch layout: 1D grid, BlockSize threads per block (BlockSize % 32 == 0),
 * one warp per row.
 *
 * @param dataset input dataset [n_rows, dim]
 * @param bf16_dataset output dataset [n_rows, dim], bf16 bit patterns stored as int16_t
 * @param sq_norms per-row squared L2 norms
 * @param noise_shaping_threshold the AVQ threshold T
 */
template <uint32_t BlockSize, typename IdxT>
__launch_bounds__(BlockSize) RAFT_KERNEL
  quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,
                                        raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                                        raft::device_vector_view<const float> sq_norms,
                                        float noise_shaping_threshold)
{
  static_assert(BlockSize % 32 == 0, "BlockSize must be a multiple of the warp size");

  // one warp per row; row_idx is uniform across the warp
  IdxT row_idx = raft::Pow2<32>::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x});

  // whole warp returns together, so later warp-wide shuffles see all 32 lanes
  if (row_idx >= dataset.extent(0)) { return; }

  uint32_t lane_id = raft::Pow2<32>::mod(threadIdx.x);

  IdxT dim = dataset.extent(1);

  // 1 / ||x||
  float inv_norm = 1 / sqrtf(sq_norms[row_idx]);
  float eta      = compute_avq_eta(dim, sq_norms[row_idx], noise_shaping_threshold);

  // < r, x > / ||x|| accumulated across the warp
  float residual_dot = 0.0f;

  // initial round-to-nearest quantization; each lane accumulates its share of the dot product
  for (int i = lane_id; i < dim; i += 32) {
    bf16_dataset(row_idx, i) = float_to_bfloat16(dataset(row_idx, i));

    float residual = dataset(row_idx, i) - bfloat16_to_float(bf16_dataset(row_idx, i));
    residual_dot += dataset(row_idx, i) * residual * inv_norm;
  }

  // reduce and broadcast residual_dot across warp
  for (uint32_t offset = 16; offset > 0; offset >>= 1) {
    residual_dot += raft::shfl_xor(residual_dot, offset, 32);
  }

  constexpr uint32_t kMaxRounds = 10;

  bool round_changes = true;
  for (int round = 0; round < kMaxRounds && round_changes; round++) {
    round_changes = false;

    for (int i = lane_id; i < dim; i += 32) {
      // coalesced reads of required data
      float original    = dataset(row_idx, i);
      int16_t quantized = bf16_dataset(row_idx, i);

      float old_residual    = original - bfloat16_to_float(quantized);
      int16_t quantized_new = bfloat16_next_delta(old_residual, quantized);

      float new_residual       = original - bfloat16_to_float(quantized_new);
      float residual_dot_delta = (new_residual - old_residual) * dataset(row_idx, i) * inv_norm;

      float residual_norm_delta = new_residual * new_residual - old_residual * old_residual;

      // we want to compute the change in cost = eta || r_parallel || ^2 + || r_perpendicular|| ^2
      // The change in || r_parallel ||^2 can be written (residual_dot + residual_dot_delta) ^ 2
      // the change in || r_perpendicular || ^2 can be written residual_norm_delta -
      // parallel_norm_delta. Thus cost_delta = eta * (residual_dot + residual_dot_delta) ^2 +
      // (residual_norm_delta - (residual_dot + residual_dot_delta)^2). Expanding and simplifying,
      // cost_delta = a + b * residual_dot, where a and b are as below. Since only residual_dot is
      // unknown (because updates must be made synchronously) we can compute a and b in parallel
      // across threads in the warp and minimize computation in the update step of the coordinate
      // descent
      float a = residual_norm_delta + (eta - 1) * residual_dot_delta * residual_dot_delta;
      float b = 2 * (eta - 1) * residual_dot_delta;

      // Dim may not be divisible by 32.
      // Only synchronize/shuffle for active threads (lanes whose i < dim this iteration).
      int active_threads = std::min<int>(32, dim - i + lane_id);
      // Guard the shift: (1 << 32) is undefined behavior, so the full-warp case (the common
      // one) must use the full mask explicitly.
      uint32_t mask = active_threads == 32 ? 0xffffffffu : (1u << active_threads) - 1u;

      // Update step for coordinate descent. Compute the cost_delta for
      // each thread, update the quantized value and residual_dot if applicable,
      // then broadcast the new residual dot to the warp.
      // The AVQ loss is not separable, so each dimension must be optimized sequentially
      for (int j = 0; j < active_threads; j++) {
        if (lane_id == j) {
          // change in AVQ loss
          float cost_delta = b * residual_dot + a;

          if (cost_delta < 0.0f) {
            quantized = quantized_new;
            residual_dot += residual_dot_delta;
            round_changes = true;
          }
        }

        // broadcast new dot product to all lanes
        residual_dot = raft::shfl(residual_dot, j, active_threads, mask);
      }

      // coalesced write of possibly updated quantized values
      bf16_dataset(row_idx, i) = quantized;
    }

    // reduce round_changes across warp; loop again if any lane changed a value
    for (uint32_t offset = 16; offset > 0; offset >>= 1) {
      round_changes |= raft::shfl_xor(round_changes, offset, 32);
    }
  }
}
434+
435+
/**
 * @brief Quantize a float dataset as bfloat16, with noise shaping (AVQ)
 *
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ
 */
template <typename IdxT>
void quantize_bfloat16_noise_shaped(raft::resources const& res,
                                    raft::device_matrix_view<const float, IdxT> dataset,
                                    raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                                    float noise_shaping_threshold)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(res);

  IdxT n_rows = dataset.extent(0);
  auto norms  = raft::make_device_vector<float, IdxT>(res, n_rows);

  // populate squared row norms; the kernel consumes ||x||^2 directly
  raft::linalg::norm<raft::linalg::NormType::L2Norm, raft::Apply::ALONG_ROWS>(
    res, dataset, norms.view());

  constexpr int64_t kBlockSize = 256;

  // one warp per row -> kBlockSize / 32 rows per block
  dim3 threads(kBlockSize, 1, 1);
  // fix: the original used undefined type name `ix_t`; the index type here is IdxT
  dim3 blocks(raft::div_rounding_up_safe<IdxT>(n_rows, kBlockSize / 32), 1, 1);

  quantize_bfloat16_noise_shaped_kernel<kBlockSize, IdxT><<<blocks, threads, 0, stream>>>(
    dataset, bf16_dataset, raft::make_const_mdspan(norms.view()), noise_shaping_threshold);

  // surface launch-configuration errors without clearing the sticky error state
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}
469+
470+
/**
 * @brief Quantize a float dataset as bfloat16, with optional noise shaping (AVQ)
 *
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ (NaN when not using AVQ)
 */
template <typename IdxT>
void quantize_bfloat16(raft::resources const& res,
                       raft::device_matrix_view<const float, IdxT> dataset,
                       raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                       float noise_shaping_threshold)
{
  if (std::isnan(noise_shaping_threshold)) {
    // AVQ disabled: plain element-wise round-to-nearest conversion
    raft::linalg::unaryOp(
      bf16_dataset.data_handle(),
      dataset.data_handle(),
      dataset.size(),
      [] __device__(float x) { return float_to_bfloat16(x); },
      resource::get_cuda_stream(res));
  } else {
    quantize_bfloat16_noise_shaped(res, dataset, bf16_dataset, noise_shaping_threshold);
  }
}
496+
270497
/**
271498
* @brief sample dataset vectors/labels and compute their residuals for PQ training
272499
*

cpp/tests/neighbors/ann_scann.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class scann_test : public ::testing::TestWithParam<scann_inputs> {
131131
ASSERT_EQ(index.pq_codebook().extent(0), num_pq_clusters);
132132
ASSERT_EQ(index.pq_codebook().extent(1), ps.dim);
133133

134-
IdxT expected_bf16_size = ps.index_params.bf16_enabled ? ps.dim * ps.num_db_vecs : 0;
134+
IdxT expected_bf16_size = ps.index_params.reordering_bf16_enabled ? ps.dim * ps.num_db_vecs : 0;
135135

136136
ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size);
137137
}
@@ -227,7 +227,7 @@ inline auto big_dims_all_pq_bits() -> test_cases_t
227227
inline auto bf16() -> test_cases_t
228228
{
229229
scann_inputs ts;
230-
ts.index_params.bf16_enabled = true;
230+
ts.index_params.reordering_bf16_enabled = true;
231231

232232
return {ts};
233233
}

0 commit comments

Comments
 (0)