[Review] ScaNN: Add option for AVQ/Noise Shaping to bfloat16 quantization #1354
@@ -16,7 +16,9 @@

#include "../../detail/vpq_dataset.cuh"
#include <chrono>
#include <cmath>
#include <cuvs/neighbors/common.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/gather.cuh>

#include "scann_common.cuh"
@@ -267,6 +269,248 @@ void unpack_codes(raft::resources const& res,
  }
}
/**
 * @brief compute eta for AVQ according to Theorem 3.4 in https://arxiv.org/abs/1908.10396
 *
 * @tparam IdxT
 * @param dim the dataset dimension
 * @param sq_norm the squared norm of the vector
 * @param threshold the threshold T in the theorem
 * @return eta
 */
template <typename IdxT>
__device__ inline float compute_avq_eta(IdxT dim, const float sq_norm, const float threshold)
{
  return (dim - 1) * (threshold * threshold / sq_norm) / (1 - threshold * threshold / sq_norm);
}
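In math notation, the return expression above is (a transcription, with $d$ = `dim`, $T$ = `threshold`, and $\|x\|^2$ = `sq_norm`; see Theorem 3.4 of the paper for the derivation):

```math
\eta = (d - 1)\,\frac{T^2 / \lVert x \rVert^2}{1 - T^2 / \lVert x \rVert^2}
```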
/**
 * @brief helper to convert a float to bfloat16 (represented as int16_t)
 *
 * @param f the float value
 * @return the bfloat16 value (as int16_t)
 */
__device__ inline int16_t float_to_bfloat16(const float& f)
{
  nv_bfloat16 val = __float2bfloat16(f);
  return reinterpret_cast<int16_t&>(val);
}
/**
 * @brief helper to convert a bfloat16 (represented as int16_t) to float
 *
 * @param bf16 the bf16 value (represented as int16_t)
 * @return the float value
 */
__device__ inline float bfloat16_to_float(int16_t& bf16)
{
  nv_bfloat16 nv_bf16 = reinterpret_cast<nv_bfloat16&>(bf16);
  return __bfloat162float(nv_bf16);
}
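As an aside, a minimal host-side sketch of what these two helpers do (not part of the PR; the `_host` names are made up). bfloat16 is simply the top 16 bits of an IEEE-754 float, so the conversions can be mimicked with plain integer operations:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int16_t float_to_bfloat16_host(float f)
{
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7fffu + ((bits >> 16) & 1u);  // round to nearest even, like __float2bfloat16
  return static_cast<int16_t>(bits >> 16);
}

float bfloat16_to_float_host(int16_t bf16)
{
  uint32_t bits = static_cast<uint32_t>(static_cast<uint16_t>(bf16)) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main()
{
  float x   = 3.14159f;
  int16_t q = float_to_bfloat16_host(x);
  // prints "3.141590 -> 3.140625": only the top 8 mantissa bits survive
  std::printf("%f -> %f\n", x, bfloat16_to_float_host(q));
}
```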
/**
 * @brief Select the next bfloat16 value to try during coordinate descent
 *
 * Based on the signs of the current residual and quantized value,
 * increment or decrement the quantized value to push the residual closer to 0.
 *
 * Note that the bfloat16 value is encoded as an int16_t, and the
 * increment/decrement is applied to the encoded value. In terms of the float
 * representation, it is the mantissa that is being incremented/decremented,
 * which can carry over into the exponent.
 *
 * @param res the float residual
 * @param current the current quantized dimension
 * @return the other possible quantized value
 */
__device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current)
{
  uint32_t res_sign = (reinterpret_cast<uint32_t&>(res) & (1u << 31)) >> 31;
Contributor: Nitpick, but would this be simpler as just …

  uint32_t curr_sign = (current & (1 << 15)) >> 15;

  if (res_sign == curr_sign) { return current - 1; }

  return current + 1;
}
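A quick host-side illustration of why the ±1 arithmetic works (a sketch with a made-up helper name, not part of the PR): for finite values of the same sign, IEEE bit patterns are ordered, so stepping the int16_t encoding steps to the neighboring representable bfloat16:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen a 16-bit bfloat16 pattern to the corresponding float.
float from_bits(uint16_t hi)
{
  uint32_t bits = static_cast<uint32_t>(hi) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main()
{
  uint16_t one = 0x3f80;  // bfloat16 encoding of 1.0f
  // prints "0.996094 1 1.00781": encoding - 1 and + 1 are the two neighbors
  std::printf("%g %g %g\n", from_bits(one - 1), from_bits(one), from_bits(one + 1));
}
```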
template <uint32_t BlockSize, typename IdxT>
__launch_bounds__(BlockSize) RAFT_KERNEL
  quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,
                                        raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
Contributor: Why are we calling this …

Contributor (author): It's convenient to use int16_t for the datatype in the index. OSS ScaNN expects an int16 matrix, so we can directly use the serialized result without any post-processing. That's where bf16_dataset comes from: a view into that large int16_t matrix. With one exception, it doesn't really matter whether we use __nv_bfloat16 or int16_t. I use the former for the float <-> __nv_bfloat16 conversion functions, but I just store the bits reinterpreted as int16_t for convenience. The only place it matters is bfloat16_next_delta. There I'm using the IEEE representation of a bfloat16 together with arithmetic operations on int16_t to generate the next bfloat16 number that is larger or smaller than the given value (as explained above, I do this by incrementing/decrementing the mantissa, which is equivalent to inc/dec of the int16_t representation). I'm not aware of a bfloat16 version of std::nextafter, which is similar to what bfloat16_next_delta does.

Contributor: I see, thanks for clarifying. Out of curiosity, would there be any perf advantage to this over using the original …

Contributor (author): Hmm, maybe; I'm not sure.

Member: @rmaschal your explanation here for why we use an integral type is great. Can we add that to the code for future eyes, please?

Contributor (author): @cjnolet the part about relating arithmetic on int16_t to finding the next smaller/larger bfloat16 is in the description of bfloat16_next_delta(..). But I added the full reasoning behind using int16_t (OSS ScaNN expects it + this AVQ-specific conversation) as a comment in the index definition in scann.hpp.
                                        raft::device_vector_view<const float> sq_norms,
                                        float noise_shaping_threshold)
{
  IdxT row_idx = raft::Pow2<32>::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x});

  if (row_idx >= dataset.extent(0)) { return; }

  uint32_t lane_id = raft::Pow2<32>::mod(threadIdx.x);

  IdxT dim = dataset.extent(1);

  // 1 / ||x||
  float inv_norm = 1 / sqrtf(sq_norms[row_idx]);
  float eta      = compute_avq_eta(dim, sq_norms[row_idx], noise_shaping_threshold);
  // < r, x >
  float residual_dot = 0.0;

  for (int i = lane_id; i < dim; i += 32) {
    bf16_dataset(row_idx, i) = float_to_bfloat16(dataset(row_idx, i));

    float residual = dataset(row_idx, i) - bfloat16_to_float(bf16_dataset(row_idx, i));
    residual_dot += dataset(row_idx, i) * residual * inv_norm;
  }

  // reduce and broadcast residual_dot across warp
  for (uint32_t offset = 16; offset > 0; offset >>= 1) {
    residual_dot += raft::shfl_xor(residual_dot, offset, 32);
  }
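To make the `< r, x >` comment precise: writing $r = x - \tilde{x}$ for the quantization residual, after the shuffle reduction every lane holds (a transcription of the accumulation loop above)

```math
\texttt{residual\_dot} = \sum_i x_i r_i / \lVert x \rVert = \langle x, r \rangle / \lVert x \rVert ,
```

the signed length of the residual component parallel to $x$, i.e. $\pm\lVert r_\parallel \rVert$ in the AVQ cost $\eta \lVert r_\parallel \rVert^2 + \lVert r_\perp \rVert^2$ optimized below.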
  constexpr uint32_t kMaxRounds = 10;

  bool round_changes = true;
  for (int round = 0; round < kMaxRounds && round_changes; round++) {
    round_changes = false;
    for (int i = lane_id; i < dim; i += 32) {
      // coalesced reads of required data
      float original    = dataset(row_idx, i);
      int16_t quantized = bf16_dataset(row_idx, i);

      float old_residual    = original - bfloat16_to_float(quantized);
      int16_t quantized_new = bfloat16_next_delta(old_residual, quantized);

      float new_residual       = original - bfloat16_to_float(quantized_new);
      float residual_dot_delta = (new_residual - old_residual) * dataset(row_idx, i) * inv_norm;

      float residual_norm_delta = new_residual * new_residual - old_residual * old_residual;
      // We want to compute the change in cost = eta * || r_parallel ||^2 + || r_perpendicular ||^2.
      // The new || r_parallel ||^2 can be written (residual_dot + residual_dot_delta)^2, and
      // || r ||^2 changes by residual_norm_delta, so the change in || r_perpendicular ||^2 is
      // residual_norm_delta minus the change in || r_parallel ||^2. Thus
      // cost_delta = eta * ((residual_dot + residual_dot_delta)^2 - residual_dot^2)
      //            + (residual_norm_delta - ((residual_dot + residual_dot_delta)^2 - residual_dot^2)).
      // Expanding and simplifying, cost_delta = a + b * residual_dot, where a and b are as below.
      // Since only residual_dot is unknown (because updates must be made synchronously), we can
      // compute a and b in parallel across threads in the warp and minimize computation in the
      // update step of the coordinate descent.
      float a = residual_norm_delta + (eta - 1) * residual_dot_delta * residual_dot_delta;
      float b = 2 * (eta - 1) * residual_dot_delta;
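Spelling out the algebra from the comment above, with $p$ = `residual_dot`, $\delta$ = `residual_dot_delta`, and $\Delta n$ = `residual_norm_delta`:

```math
\begin{aligned}
\Delta\mathrm{cost} &= \eta\big[(p+\delta)^2 - p^2\big] + \Big(\Delta n - \big[(p+\delta)^2 - p^2\big]\Big) \\
                    &= \Delta n + (\eta - 1)\big(\delta^2 + 2 p \delta\big) \\
                    &= \underbrace{\Delta n + (\eta - 1)\delta^2}_{a} + \underbrace{2(\eta - 1)\delta}_{b}\; p ,
\end{aligned}
```

which matches `a` and `b` as computed, so each lane only needs the up-to-date `residual_dot` to evaluate its `cost_delta`.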
      // Dim may not be divisible by 32
      // Only synchronize/shuffle for active threads
      int active_threads = std::min<int>(32, dim - i + lane_id);
      uint32_t mask      = static_cast<uint32_t>((1ull << active_threads) - 1);
      // Update step for coordinate descent. Compute the cost_delta for
      // each thread, update the quantized value and residual_dot if applicable,
      // then broadcast the new residual_dot to the warp.
      // The AVQ loss is not separable, so we must optimize each dimension one at a time.
      for (int j = 0; j < active_threads; j++) {
        if (lane_id == j) {
          // change in AVQ loss
          float cost_delta = b * residual_dot + a;

          if (cost_delta < 0.0) {
            quantized = quantized_new;
            residual_dot += residual_dot_delta;
            round_changes = true;
          }
        }

        // broadcast new dot product to all lanes
        residual_dot = raft::shfl(residual_dot, j, active_threads, mask);
      }
      // coalesced write of possibly updated quantized values
      bf16_dataset(row_idx, i) = quantized;
    }

    // reduce round_changes across warp
    for (uint32_t offset = 16; offset > 0; offset >>= 1) {
      round_changes |= raft::shfl_xor(round_changes, offset, 32);
    }
  }
}
/**
 * @brief Quantize a float dataset as bfloat16, with noise shaping (AVQ)
 *
 * During quantization we replace each input vector coordinate `f` of type float32 with a bfloat16
 * coordinate `b`. One way to do this would be to simply assign the nearest bfloat16 value to
 * each coordinate. This would be the best way to quantize if we wanted to minimize the L2
 * distance between the quantized and the original vector.
 *
 * In the AVQ method, we use a different cost function. To minimize it, we consider the nearest
 * representable bfloat16 values (`b1`, `b2`) around `f`, and select the one that minimizes the AVQ
 * cost function. In two dimensions we need to consider the four neighboring quantized vectors:
 *     b1  b2
 *       f
 *     b3  b4
 *
 * In N dimensions we select one of the vertices of an N-dimensional hypercube as the quantized
 * vector. To find the minimum without enumerating all the combinations, a coordinate descent
 * method is used.
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ
 */
template <typename IdxT>
void quantize_bfloat16_noise_shaped(raft::resources const& res,
                                    raft::device_matrix_view<const float, IdxT> dataset,
                                    raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                                    float noise_shaping_threshold)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(res);

  IdxT n_rows = dataset.extent(0);
  auto norms  = raft::make_device_vector<float, IdxT>(res, n_rows);

  // populate squared norms
  raft::linalg::norm<raft::linalg::NormType::L2Norm, raft::Apply::ALONG_ROWS>(
    res, dataset, norms.view());

  constexpr int64_t kBlockSize = 256;

  dim3 threads(kBlockSize, 1, 1);
  dim3 blocks(raft::div_rounding_up_safe<IdxT>(n_rows, kBlockSize / 32), 1, 1);

  quantize_bfloat16_noise_shaped_kernel<kBlockSize, IdxT><<<blocks, threads, 0, stream>>>(
    dataset, bf16_dataset, raft::make_const_mdspan(norms.view()), noise_shaping_threshold);

  RAFT_CUDA_TRY(cudaPeekAtLastError());
}
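As a quick sanity check on the launch geometry: the kernel assigns one warp per row, so with `kBlockSize = 256` each block carries 256 / 32 = 8 warps, and e.g. `n_rows = 1000` (a made-up size) yields `div_rounding_up_safe(1000, 8) = 125` blocks.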
/**
 * @brief Quantize a float dataset as bfloat16, with optional noise shaping (AVQ)
 *
 * @tparam IdxT
 * @param res raft resources
 * @param dataset the dataset (device only) size [n_rows, dim]
 * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim]
 * @param noise_shaping_threshold the threshold for AVQ (nan when not using AVQ)
 */
template <typename IdxT>
void quantize_bfloat16(raft::resources const& res,
                       raft::device_matrix_view<const float, IdxT> dataset,
                       raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
                       float noise_shaping_threshold)
{
  if (!std::isnan(noise_shaping_threshold)) {
    quantize_bfloat16_noise_shaped(res, dataset, bf16_dataset, noise_shaping_threshold);
  } else {
    raft::linalg::unaryOp(
      bf16_dataset.data_handle(),
      dataset.data_handle(),
      dataset.size(),
      [] __device__(float x) { return float_to_bfloat16(x); },
      raft::resource::get_cuda_stream(res));
  }
}
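For orientation, a hypothetical call site (a sketch only: the sizes and the 0.2f threshold are made up; the NaN convention for disabling noise shaping comes from the docstring above):

```cpp
#include <cstdint>
#include <limits>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>

void quantize_example(raft::resources const& res)
{
  int64_t n_rows = 1000, dim = 96;
  auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, dim);
  auto bf16    = raft::make_device_matrix<int16_t, int64_t>(res, n_rows, dim);
  // ... fill `dataset` ...

  // with AVQ / noise shaping, threshold T = 0.2 (made-up value):
  quantize_bfloat16<int64_t>(res, raft::make_const_mdspan(dataset.view()), bf16.view(), 0.2f);

  // plain round-to-nearest quantization (NaN disables noise shaping):
  quantize_bfloat16<int64_t>(res,
                             raft::make_const_mdspan(dataset.view()),
                             bf16.view(),
                             std::numeric_limits<float>::quiet_NaN());
}
```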
/**
 * @brief sample dataset vectors/labels and compute their residuals for PQ training
 *