Skip to content
Merged
38 changes: 31 additions & 7 deletions cpp/include/cuvs/neighbors/scann.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <raft/util/integer_utils.hpp>
#include <rmm/cuda_stream_view.hpp>

#include <cmath>
#include <optional>
#include <variant>

Expand Down Expand Up @@ -74,9 +75,20 @@ struct index_params : cuvs::neighbors::index_params {
uint32_t pq_train_iters = 10;

/** whether to apply bf16 quantization of dataset vectors **/
bool bf16_enabled = false;

// TODO - add other scann build params
bool reordering_bf16 = false;
Comment thread
rmaschal marked this conversation as resolved.

/** Threshold T for computing AVQ eta = (dim - 1) ( T^2 / || x ||^2) / ( 1 - T^2 / || x ||^2)
*
* When quantizing a vector x to x_q, AVQ minimizes the loss function
* L(x, x_q) = eta * || r_para ||^2 + || r_perp ||^2, where
* r = x - x_q, r_para = <r, x> * x / || x ||^2, r_perp = r - r_para
*
 * Compared to L2 loss, this produces an x_q which better approximates
* the dot product of a query vector with x
*
* If the threshold is NAN, AVQ is not performed during bfloat16 quant
Comment thread
cjnolet marked this conversation as resolved.
*/
float reordering_noise_shaping_threshold = NAN;
};

/**
Expand Down Expand Up @@ -136,7 +148,7 @@ struct index : cuvs::neighbors::index {
IdxT dim,
uint32_t pq_clusters,
uint32_t pq_num_subspaces,
bool bf16_enabled)
bool reordering_bf16)
: cuvs::neighbors::index(),
metric_(metric),
pq_dim_(pq_dim),
Expand All @@ -154,7 +166,7 @@ struct index : cuvs::neighbors::index {
n_rows_(n_rows),
dim_(dim),
bf16_dataset_(raft::make_host_matrix<int16_t, IdxT, raft::row_major>(
bf16_enabled ? n_rows : 0, bf16_enabled ? dim : 0))
reordering_bf16 ? n_rows : 0, reordering_bf16 ? dim : 0))

{
}
Expand All @@ -169,7 +181,7 @@ struct index : cuvs::neighbors::index {
dim,
1 << params.pq_bits,
dim / params.pq_dim,
params.bf16_enabled)
params.reordering_bf16)
{
RAFT_EXPECTS(params.pq_bits == 4 || params.pq_bits == 8, "ScaNN only supports 4 or 8 bit PQ");
RAFT_EXPECTS(dim >= params.pq_dim,
Expand Down Expand Up @@ -260,9 +272,21 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<float, uint32_t, raft::row_major> pq_codebook_;
raft::host_matrix<uint8_t, IdxT, raft::row_major> quantized_residuals_;
raft::host_matrix<uint8_t, IdxT, raft::row_major> quantized_soar_residuals_;

/* Internally, __nv_bfloat16 is used for float <-> __nv_bfloat16 conversion.
* The bits of __nv_bfloat16 are stored here reinterpreted as int16_t
*
 * int16_t is used for two reasons:
 * * OSS ScaNN expects int16_t, so the serialized bf16_dataset_ can be consumed
* without any additional post-processing
* * For AVQ, we need to find the next bfloat16 number that is larger/smaller than a
* given float. This is equivalent to incrementing/decrementing the mantissa
* in IEEE representation of the bfloat16 number, which in turn is equivalent
* to incrementing/decrementing the int16_t representation
*/
raft::host_matrix<int16_t, IdxT, raft::row_major> bf16_dataset_;
// TODO - add any data, pointers or structures needed
};

/**
* @}
*/
Expand Down
15 changes: 4 additions & 11 deletions cpp/src/neighbors/scann/detail/scann_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,9 @@ index<T, IdxT> build(
// TODO (rmaschal): Might be more efficient to do on CPU, to avoid DtoH copy
auto bf16_dataset = raft::make_device_matrix<int16_t, int64_t>(res, batch_view.extent(0), dim);

if (params.bf16_enabled) {
raft::linalg::unaryOp(
bf16_dataset.data_handle(),
batch_view.data_handle(),
batch_view.size(),
[] __device__(T x) {
nv_bfloat16 val = __float2bfloat16(x);
return reinterpret_cast<int16_t&>(val);
},
resource::get_cuda_stream(res));
if (params.reordering_bf16) {
quantize_bfloat16(
Comment thread
cjnolet marked this conversation as resolved.
res, batch_view, bf16_dataset.view(), params.reordering_noise_shaping_threshold);
}

// Prefetch next batch
Expand All @@ -340,7 +333,7 @@ index<T, IdxT> build(
quantized_soar_residuals.size(),
stream);

if (params.bf16_enabled) {
if (params.reordering_bf16) {
raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim,
bf16_dataset.data_handle(),
bf16_dataset.size(),
Expand Down
244 changes: 244 additions & 0 deletions cpp/src/neighbors/scann/detail/scann_quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include "../../detail/vpq_dataset.cuh"
#include <chrono>
#include <cmath>
#include <cuvs/neighbors/common.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/gather.cuh>

#include "scann_common.cuh"
Expand Down Expand Up @@ -267,6 +269,248 @@ void unpack_codes(raft::resources const& res,
}
}

/**
 * @brief compute eta for AVQ according to Theorem 3.4 in https://arxiv.org/abs/1908.10396
 *
 * eta weights the parallel component of the quantization residual in the AVQ
 * loss L(x, x_q) = eta * || r_para ||^2 + || r_perp ||^2.
 *
 * @tparam IdxT index type of the dataset dimension
 * @param dim the dataset dimension
 * @param sq_norm the squared norm || x ||^2 of the vector
 * @param threshold the threshold T in the Theorem
 * @return eta
 */
template <typename IdxT>
__device__ inline float compute_avq_eta(IdxT dim, const float sq_norm, const float threshold)
{
  // eta = (dim - 1) * (T^2 / ||x||^2) / (1 - T^2 / ||x||^2)
  const float ratio = threshold * threshold / sq_norm;
  return (dim - 1) * ratio / (1 - ratio);
}

/**
 * @brief helper to convert a float to bfloat16 (represented as int16_t)
 *
 * @param f the float value
 * @return the bfloat16 value (as int16_t)
 */
__device__ inline int16_t float_to_bfloat16(const float& f)
{
  // __bfloat16_as_short reinterprets the bfloat16 bit pattern directly;
  // unlike reinterpret_cast<int16_t&> it does not violate strict aliasing
  return __bfloat16_as_short(__float2bfloat16(f));
}

/**
 * @brief helper to convert a bfloat16 (represented as int16_t) to float
 *
 * @param bf16 the bf16 value (represented as int16_t)
 * @return the float value
 */
__device__ inline float bfloat16_to_float(int16_t& bf16)
{
  // __short_as_bfloat16 reinterprets the int16_t bit pattern directly;
  // unlike reinterpret_cast<nv_bfloat16&> it does not violate strict aliasing
  return __bfloat162float(__short_as_bfloat16(bf16));
}

/**
 * @brief Select the next bfloat16 value to try during coordinate descent
 *
 * Based on the signs of the current residual and quantized value,
 * increment or decrement the quantized value to push the residual closer to 0.
 *
 * Note that the bfloat16 value is encoded as an int16_t, and the
 * increment/decrement is applied to the encoded value. In terms of the float
 * representation, it is the mantissa that is being incremented/decremented,
 * which could carry over to the exponent.
 *
 * Direction: with residual r = x - q, r > 0 means q must grow toward x.
 * For q >= 0 a larger value means incrementing the encoding; for q < 0 a
 * larger value (toward zero) means decrementing it. Both cases collapse to:
 * equal signs of r and q -> increment, differing signs -> decrement.
 *
 * @param res the float residual
 * @param current the current quantized dimension
 * @return the other possible quantized value
 */
__device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current)
{
  // Take the IEEE-754 sign bit of the residual via a bit reinterpretation.
  // A value cast like (int32_t)res truncates toward zero, which would map
  // every residual with |res| < 1 to 0 and lose its sign.
  uint32_t res_sign  = __float_as_uint(res) >> 31;
  uint32_t curr_sign = (static_cast<uint16_t>(current) >> 15) & 1u;

  if (res_sign == curr_sign) { return current + 1; }

  return current - 1;
}

/**
 * @brief Kernel performing bfloat16 quantization with AVQ noise shaping
 *
 * Launch layout: one warp (32 lanes) per dataset row; BlockSize threads per
 * block, so each block processes BlockSize / 32 rows.
 *
 * @tparam BlockSize threads per block (must be a multiple of 32)
 * @tparam IdxT matrix index type
 * @param dataset input dataset [n_rows, dim]
 * @param bf16_dataset output quantized dataset [n_rows, dim], bfloat16 bits stored as int16_t
 * @param sq_norms squared L2 norms of the rows of dataset
 * @param noise_shaping_threshold the threshold T used to compute AVQ eta
 */
template <uint32_t BlockSize, typename IdxT>
__launch_bounds__(BlockSize) RAFT_KERNEL
  quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,
                                        raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                                        raft::device_vector_view<const float> sq_norms,
                                        float noise_shaping_threshold)
{
  // One warp per row: global thread index / 32
  IdxT row_idx = raft::Pow2<32>::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x});

  if (row_idx >= dataset.extent(0)) { return; }

  uint32_t lane_id = raft::Pow2<32>::mod(threadIdx.x);

  IdxT dim = dataset.extent(1);

  // 1 / ||x||
  float inv_norm = 1 / sqrtf(sq_norms[row_idx]);
  float eta      = compute_avq_eta(dim, sq_norms[row_idx], noise_shaping_threshold);

  // < r, x > / ||x||, accumulated across the warp
  float residual_dot = 0.0f;

  // Initial round-to-nearest quantization; also seed the residual dot product
  for (int i = lane_id; i < dim; i += 32) {
    bf16_dataset(row_idx, i) = float_to_bfloat16(dataset(row_idx, i));

    float residual = dataset(row_idx, i) - bfloat16_to_float(bf16_dataset(row_idx, i));
    residual_dot += dataset(row_idx, i) * residual * inv_norm;
  }

  // reduce and broadcast residual_dot across warp
  for (uint32_t offset = 16; offset > 0; offset >>= 1) {
    residual_dot += raft::shfl_xor(residual_dot, offset, 32);
  }

  constexpr uint32_t kMaxRounds = 10;

  bool round_changes = true;
  for (int round = 0; round < kMaxRounds && round_changes; round++) {
    round_changes = false;

    for (int i = lane_id; i < dim; i += 32) {
      // coalesced reads of required data
      float original    = dataset(row_idx, i);
      int16_t quantized = bf16_dataset(row_idx, i);

      float old_residual    = original - bfloat16_to_float(quantized);
      int16_t quantized_new = bfloat16_next_delta(old_residual, quantized);

      float new_residual       = original - bfloat16_to_float(quantized_new);
      float residual_dot_delta = (new_residual - old_residual) * dataset(row_idx, i) * inv_norm;

      float residual_norm_delta = new_residual * new_residual - old_residual * old_residual;

      // We want to compute the change in cost = eta * || r_parallel ||^2 + || r_perpendicular ||^2
      // The change in || r_parallel ||^2 can be written (residual_dot + residual_dot_delta)^2
      // and the change in || r_perpendicular ||^2 can be written residual_norm_delta -
      // parallel_norm_delta. Thus cost_delta = eta * (residual_dot + residual_dot_delta)^2 +
      // (residual_norm_delta - (residual_dot + residual_dot_delta)^2). Expanding and simplifying,
      // cost_delta = a + b * residual_dot, where a and b are as below. Since only residual_dot is
      // unknown (because updates must be made synchronously) we can compute a and b in parallel
      // across threads in the warp and minimize computation in the update step of the coordinate
      // descent
      float a = residual_norm_delta + (eta - 1) * residual_dot_delta * residual_dot_delta;
      float b = 2 * (eta - 1) * residual_dot_delta;

      // Dim may not be divisible by 32; i - lane_id is warp-uniform, so active_threads
      // and mask agree across the participating lanes.
      // Only synchronize/shuffle for active threads
      int active_threads = std::min<int>(32, dim - i + lane_id);
      // NOTE: built from the all-ones mask because (1 << 32) is undefined behavior
      // when all 32 lanes are active
      unsigned mask = 0xffffffffu >> (32 - active_threads);

      // Update step for coordinate descent. Compute the cost_delta for
      // each thread, update the quantized value and residual_dot if applicable,
      // then broadcast the new residual dot to the warp.
      // The AVQ loss is not separable, so we must optimize each dimension separately
      for (int j = 0; j < active_threads; j++) {
        if (lane_id == j) {
          // change in AVQ loss
          float cost_delta = b * residual_dot + a;

          if (cost_delta < 0.0f) {
            quantized = quantized_new;
            residual_dot += residual_dot_delta;
            round_changes = true;
          }
        }

        // broadcast new dot product to all lanes
        residual_dot = raft::shfl(residual_dot, j, active_threads, mask);
      }

      // coalesced write of possibly updated quantized values
      bf16_dataset(row_idx, i) = quantized;
    }

    // reduce round_changes across warp so every lane agrees on the loop condition
    for (uint32_t offset = 16; offset > 0; offset >>= 1) {
      round_changes |= raft::shfl_xor(round_changes, offset, 32);
    }
  }
}

/**
 * @brief Quantize a float dataset as bfloat16, with noise shaping (AVQ)
 *
 * During quantization we replace each input vector coordinate `f` of type float32 with a bfloat16
 * coordinate `b`. One way to do this would be to simply assign the nearest bfloat16 value to
 * each coordinate. This would be the best way to quantize if we want to minimize the L2
 * distance between the quantized and the original vector.
 *
 * In the AVQ method, we use a different cost function. To minimize that, we consider nearest
 * representable bfloat16 values (`b1`, `b2`) around `f`, and select the one that minimizes the AVQ
 * cost function. In two dimensions we need to consider the four neighboring quantized vectors:
 * b1 b2
 *    f
 * b3 b4
 *
 * In N dimensions we will select one of the vertices of an N dimensional hypercube as the
 * quantized vector. To find the minimum without enumerating all the combinations, a coordinate
 * descent method is used.
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ
 */
template <typename IdxT>
void quantize_bfloat16_noise_shaped(raft::resources const& res,
                                    raft::device_matrix_view<const float, IdxT> dataset,
                                    raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                                    float noise_shaping_threshold)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(res);

  IdxT n_rows = dataset.extent(0);
  auto norms  = raft::make_device_vector<float, IdxT>(res, n_rows);

  // populate square norms
  raft::linalg::norm<raft::linalg::NormType::L2Norm, raft::Apply::ALONG_ROWS>(
    res, dataset, norms.view());

  constexpr int64_t kBlockSize = 256;

  dim3 threads(kBlockSize, 1, 1);
  // one warp per row -> kBlockSize / 32 rows per block
  // (index type is IdxT; `ix_t` is not declared in this scope)
  dim3 blocks(raft::div_rounding_up_safe<IdxT>(n_rows, kBlockSize / 32), 1, 1);

  quantize_bfloat16_noise_shaped_kernel<kBlockSize, IdxT><<<blocks, threads, 0, stream>>>(
    dataset, bf16_dataset, raft::make_const_mdspan(norms.view()), noise_shaping_threshold);

  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/**
 * @brief Quantize a float dataset as bfloat16, with optional noise shaping (AVQ)
 *
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ (nan when not using AVQ)
 */
template <typename IdxT>
void quantize_bfloat16(raft::resources const& res,
                       raft::device_matrix_view<const float, IdxT> dataset,
                       raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                       float noise_shaping_threshold)
{
  if (std::isnan(noise_shaping_threshold)) {
    // NaN threshold disables AVQ: plain elementwise round-to-nearest bfloat16
    raft::linalg::unaryOp(
      bf16_dataset.data_handle(),
      dataset.data_handle(),
      dataset.size(),
      [] __device__(float x) { return float_to_bfloat16(x); },
      resource::get_cuda_stream(res));
    return;
  }

  // Otherwise run the noise-shaped (AVQ) quantization path
  quantize_bfloat16_noise_shaped(res, dataset, bf16_dataset, noise_shaping_threshold);
}

/**
* @brief sample dataset vectors/labels and compute their residuals for PQ training
*
Expand Down
13 changes: 11 additions & 2 deletions cpp/tests/neighbors/ann_scann.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class scann_test : public ::testing::TestWithParam<scann_inputs> {
ASSERT_EQ(index.pq_codebook().extent(0), num_pq_clusters);
ASSERT_EQ(index.pq_codebook().extent(1), ps.dim);

IdxT expected_bf16_size = ps.index_params.bf16_enabled ? ps.dim * ps.num_db_vecs : 0;
IdxT expected_bf16_size = ps.index_params.reordering_bf16 ? ps.dim * ps.num_db_vecs : 0;

ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size);
}
Expand Down Expand Up @@ -227,7 +227,16 @@ inline auto big_dims_all_pq_bits() -> test_cases_t
inline auto bf16() -> test_cases_t
{
scann_inputs ts;
ts.index_params.bf16_enabled = true;
ts.index_params.reordering_bf16 = true;

return {ts};
}

inline auto bf16_avq() -> test_cases_t
{
scann_inputs ts;
ts.index_params.reordering_bf16 = true;
ts.index_params.reordering_noise_shaping_threshold = 0.2;

return {ts};
}
Expand Down
Loading