[Review] ScaNN: Add option for AVQ/Noise Shaping to bfloat16 quantization #1354

Merged
rapids-bot[bot] merged 10 commits into rapidsai:main from rmaschal:bfloat-avq on Oct 21, 2025
Conversation

@rmaschal
Contributor

This PR adds support for AVQ loss/Noise shaping to the BFloat16 dataset quantization.

AVQ loss is a modified version of L2 loss which separately penalizes the components of the residual vector that are parallel and perpendicular to the original vector. Quantizing vectors with AVQ loss rather than L2 loss gives a better approximation of the inner product, and thus performs better in Maximum Inner Product Search (https://arxiv.org/abs/1908.10396).

Math:
x : original vector
x_q : quantized vector
r = x - x_q : residual vector
r_para = < r , x > x / || x ||^2 : parallel component of the residual
r_perp = r - r_para : perpendicular component of the residual
eta >= 1 : AVQ parameter
AVQ loss = eta * || r_para ||^2 + || r_perp ||^2

For a float vector x, the goal is to find a bfloat16 vector x_q that minimizes the AVQ loss for a given eta. Unlike L2 loss, AVQ loss is not separable across dimensions (||r_para||^2 contains cross terms from the inner product), so we cannot optimize individual dimensions independently in parallel and expect convergence. Instead, we use coordinate descent to optimize the dimensions of x_q one at a time until convergence.
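
To make the objective and the sweep concrete, here is a rough host-side reference (illustrative only, not the kernel added in this PR: the helper names are invented for this sketch, and it recomputes the full loss for every candidate instead of maintaining a running < r , x > like the kernel does):

```cuda
// Rough host-side reference of the AVQ objective and one coordinate-descent
// sweep. Illustrative only: the real kernel maintains a running < r , x >
// instead of recomputing the full loss for every candidate, and the helper
// names here are invented for this sketch.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// The two bfloat16 values bracketing x: truncation (toward zero) and the next
// representable value away from zero. Overflow/NaN edge cases are ignored.
void bf16_neighbors(float x, float& toward_zero, float& away_from_zero)
{
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  uint32_t t = u & 0xFFFF0000u;  // drop the low 16 bits -> rounds toward zero
  uint32_t n = t + 0x00010000u;  // bump the bfloat16 bits -> next value away from zero
  std::memcpy(&toward_zero, &t, sizeof(t));
  std::memcpy(&away_from_zero, &n, sizeof(n));
}

// AVQ loss = eta * ||r_para||^2 + ||r_perp||^2, with r = x - x_q.
double avq_loss(const std::vector<float>& x, const std::vector<float>& x_q, double eta)
{
  double x_norm2 = 0.0, r_dot_x = 0.0, r_norm2 = 0.0;
  for (size_t i = 0; i < x.size(); ++i) {
    double r = double(x[i]) - double(x_q[i]);
    x_norm2 += double(x[i]) * x[i];
    r_dot_x += r * x[i];
    r_norm2 += r * r;
  }
  double para = (x_norm2 > 0.0) ? r_dot_x * r_dot_x / x_norm2 : 0.0;  // ||r_para||^2
  return eta * para + (r_norm2 - para);                               // + ||r_perp||^2
}

// One sweep: for each dimension, keep whichever bracketing bfloat16 value
// gives the lower global AVQ loss, holding all other dimensions fixed.
void avq_sweep(const std::vector<float>& x, std::vector<float>& x_q, double eta)
{
  for (size_t i = 0; i < x.size(); ++i) {
    float lo, hi;
    bf16_neighbors(x[i], lo, hi);
    float best       = x_q[i];
    double best_loss = avq_loss(x, x_q, eta);
    for (float cand : {lo, hi}) {
      x_q[i]      = cand;
      double loss = avq_loss(x, x_q, eta);
      if (loss < best_loss) { best_loss = loss; best = cand; }
    }
    x_q[i] = best;
  }
}

int main()
{
  std::vector<float> x = {0.31f, -1.7f, 2.05f, 0.003f};
  std::vector<float> x_q(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float lo, hi;
    bf16_neighbors(x[i], lo, hi);
    x_q[i] = lo;  // start from plain truncation
  }
  double eta = 2.0;
  std::printf("AVQ loss before: %g\n", avq_loss(x, x_q, eta));
  for (int it = 0; it < 10; ++it) avq_sweep(x, x_q, eta);
  std::printf("AVQ loss after:  %g\n", avq_loss(x, x_q, eta));
  return 0;
}
```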

This coordinate descent happens in the new kernel "quantize_bfloat16_noise_shaped_kernel". For efficient memory accesses and compute, one warp is assigned to optimize each dataset vector. The computation of the AVQ loss is algebraically separated into two pieces: terms that can be computed in parallel (i.e. those depending only on local information for the assigned dimension) and terms that require global information (namely those depending on < r , x >). Finally, the threads in a warp serialize to compute the final cost for their dimension, update the quantized value and the value of < r , x > (if applicable), and broadcast the updated value of < r , x > to the other threads. This continues in blocks of 32 dimensions until convergence (or a maximum of 10 iterations).
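
To illustrate just the serialize-and-broadcast part of that pattern, here is a tiny standalone CUDA sketch (not the PR's kernel: the per-lane "contribution" below is a made-up stand-in for the real cost/residual update, and only the running < r , x > bookkeeping is shown):

```cuda
// Minimal sketch: one warp per row, lanes take turns folding their term into a
// running value of < r , x > and broadcasting it with __shfl_sync so the next
// lane decides with up-to-date global information.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void warp_serialized_update(const float* x, float* r_dot_x_out, int dim)
{
  int lane      = threadIdx.x & 31;  // lane id within the warp
  float r_dot_x = 0.0f;              // running < r , x >, replicated per lane

  // Walk the row in blocks of 32 dimensions (one dimension per lane).
  for (int base = 0; base < dim; base += 32) {
    int d         = base + lane;
    float contrib = (d < dim) ? 0.1f * x[d] : 0.0f;  // stand-in for this lane's residual term

    // Lanes serialize: lane src folds its contribution in, then broadcasts the
    // updated value so lane src+1 sees it before making its own decision.
    for (int src = 0; src < 32; ++src) {
      if (lane == src) r_dot_x += contrib;
      r_dot_x = __shfl_sync(0xffffffffu, r_dot_x, src);
    }
  }
  if (lane == 0) *r_dot_x_out = r_dot_x;
}

int main()
{
  const int dim = 64;
  float h_x[dim];
  for (int i = 0; i < dim; ++i) h_x[i] = float(i);

  float *d_x, *d_out;
  cudaMalloc(&d_x, dim * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_x, h_x, dim * sizeof(float), cudaMemcpyHostToDevice);

  warp_serialized_update<<<1, 32>>>(d_x, d_out, dim);

  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("running < r , x > after sweep: %f\n", h_out);

  cudaFree(d_x);
  cudaFree(d_out);
  return 0;
}
```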

I've found this strategy does a good job of taking advantage of the inherently row-major structure of the dataset/index for efficient coalesced accesses, while still making good use of compute resources (hitting >90% compute throughput on an A6000).

Besides the coordinate descent kernel, this PR adds some helper functions for the above, refactors the existing bfloat16 quantization to take advantage of them, and adds configuration for the AVQ eta (the code uses normal bfloat16 quantization when the AVQ threshold is NaN).

@rmaschal rmaschal requested a review from a team as a code owner September 22, 2025 18:45
@rmaschal rmaschal added the improvement (Improves an existing functionality) and non-breaking (Introduces a non-breaking change) labels Sep 22, 2025
@rmaschal rmaschal self-assigned this Sep 22, 2025
Comment thread cpp/include/cuvs/neighbors/scann.hpp Outdated
Comment thread cpp/src/neighbors/scann/detail/scann_build.cuh
stream);

if (params.bf16_enabled) {
if (params.reordering_bf16_enabled) {
Member

I'm looking over your implementation here trying to figure out the expectations from the user. What I'm gathering is that this is taking a dataset in float and quantizing to bfloat16.

Is there a benefit to doing the quantization inside the build process instead of the user quantizing up front and then passing the quantized dataset into build()? Is it because it's quantizing independently for each IVF list rather than quantizing the dataset as a whole?

Contributor Author

It's quantizing the dataset as a whole, but the full precision dataset is used for all other steps of the build process, not the quantized dataset. So the user quantizing first and passing that in wouldn't make sense.

I do the quantization in this particular loop mostly for locality reasons. We already have a batch on GPU for the other SOAR/quantization steps, so we might as well do the bfloat16 too.

For regular bfloat16 quantization (basically chopping off the last 16 bits) it's not really important that we do it right here. In fact, the extra DtoH copy might be more expensive than doing regular bfloat16 quantization right on the CPU. I haven't tried yet.

For bfloat16 w/ AVQ (introduced in this PR), we definitely want to use the GPU, and I don't think it makes sense to introduce another HtoD copy of the dataset to do it elsewhere (e.g. outside of build).

Member

> For bfloat16 w/ AVQ (introduced in this PR), we definitely want to use the GPU, and I don't think it makes sense to introduce another HtoD copy of the dataset to do it elsewhere (e.g. outside of build)

The question here is more about reuse than it is about whether or not you should be doing this in the build. While I'm not particularly fond of the idea of maintaining copies in device memory, I think I can accept that more. What I'm concerned about is having different implementations of the same sorts of computations, transformations, and computational functions duplicated throughout the codebase. That's why I ask the question: "Can this be its own API, or is it inherently coupled to SCaNN?"

Just as a heads up, if we end up finding that it's more reusable, we're going to want to pull it out into its own thing, and if that means introducing an additional copy or small perf hit as a result, that's a trade-off we are generally willing to make in order to keep the codebase clean and maintainable. However, I'm fine keeping it where it is in the meantime if it's the only place in the codebase that needs to do it. Does that make sense?

Contributor Author

Makes sense to me. In that case, your concern applies to several things I've added and will add to ScaNN (anything related to AVQ, SOAR, etc.). I've generally added these under ScaNN for my convenience, but in principle they could all be generalized. There is no inherent coupling with "ScaNN" (which is a library, not an algorithm). The coupling is only with the query distance function (it should be the dot product).

The perf issue I've seen shows up with other functions (the predict function for VPQ comes to mind). I think this could be solved by providing a single-batch device API, and a batched (host or device) API which calls the single-batch API. I can use the first in my build function since I want control of the batching (e.g. for good locality), and the batched APIs are more general/ergonomic. I feel like this is out of scope for this PR, but I think this pattern makes sense.

Member

@cjnolet cjnolet left a comment

Looking good so far. We need to make sure our public APIs are well documented, especially for users not inherently familiar w/ SCaNN.

Comment thread cpp/tests/neighbors/ann_scann.cuh Outdated
Comment thread cpp/include/cuvs/neighbors/scann.hpp
Comment thread cpp/include/cuvs/neighbors/scann.hpp Outdated
* If the threshold is NAN, AVQ is not performed during bfloat16 quant
*/
float reordering_noise_shaping_threshold = NAN;
// TODO - add other scann build params
Member

Please create a GitHub issue for this and reference it here. As a general rule, we don't just drop TODOs in comments without having a corresponding GitHub issue for you or someone else to follow up.

Contributor

@tfeher tfeher left a comment

Hi Robert, thank you for the PR. I am looking through it; here are my first two comments.

Comment thread cpp/src/neighbors/scann/detail/scann_quantize.cuh
Comment thread cpp/src/neighbors/scann/detail/scann_quantize.cuh
Comment thread cpp/src/neighbors/scann/detail/scann_quantize.cuh
@rmaschal rmaschal force-pushed the bfloat-avq branch 2 times, most recently from 6c7f3e0 to 6a84f5a October 1, 2025 20:44
Contributor

@bkarsin bkarsin left a comment

Just a few clarifying questions / nitpicks on my side. Aside from these, looks good to me.

Comment thread cpp/include/cuvs/neighbors/scann.hpp
*/
__device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current)
{
uint32_t res_sign = ((int32_t)res & (1u << 31) >> 31);
Contributor

Nitpick, but would this be simpler as just (int32_t)res >= 0?

Comment thread cpp/src/neighbors/scann/detail/scann_quantize.cuh
template <uint32_t BlockSize, typename IdxT>
__launch_bounds__(BlockSize) RAFT_KERNEL
quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,
raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
Contributor

Why are we calling this bf16 if it's an int16 matrix? It seems you're casting it to bf16 or float throughout, but wouldn't it still always fall in the range of -2^15 to 2^15? Sorry if I'm missing the main bf16 benefit here.

Contributor Author

It's convenient to use int16_t for the datatype in the index. OSS ScaNN expects an int16 matrix, so we can directly use the serialized result without any post processing. That's where bf16_dataset comes from, a view into that large int16_t matrix. With one exception, it doesn't really matter whether we use __nv_bfloat16 or int16_t. I use the former for the float <-> __nv_bfloat16 conversion functions, but I just store the bits reinterpreted as int16_t for convenience.

The only place it matters is bfloat16_next_delta. There I'm using the IEEE representation of a bfloat16 together with arithmetic operations on int16_t to generate the next bfloat16 number that is larger or smaller than the given value (as explained above, I do this by incrementing/decrementing the mantissa, which is equivalent to incrementing/decrementing the int16_t representation). I'm not aware of a bfloat16 version of std::nextafter, which is similar to what bfloat16_next_delta does.
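
For anyone unfamiliar with the trick, here is a small standalone sketch of the idea (illustrative only; the function names are made up here and NaN/overflow edge cases are ignored):

```cuda
// Within one sign, bfloat16 bit patterns are ordered the same way as the
// values they encode, so adding or subtracting 1 from the raw 16-bit pattern
// steps to the adjacent representable value, much like std::nextafter does
// for float/double.
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bf16_bits(float x)  // truncate: keep the high 16 bits
{
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return uint16_t(u >> 16);
}

float bf16_bits_to_float(uint16_t b)  // expand back to float for inspection
{
  uint32_t u = uint32_t(b) << 16;
  float x;
  std::memcpy(&x, &u, sizeof(u));
  return x;
}

// Next representable bfloat16 toward +infinity (NaN/overflow cases ignored).
uint16_t bf16_next_up(uint16_t b)
{
  if (b == 0x8000) return 0x0000;        // -0 steps to +0
  return (b & 0x8000) ? uint16_t(b - 1)  // negative: shrink the magnitude
                      : uint16_t(b + 1); // non-negative: grow the magnitude
}

int main()
{
  uint16_t one = float_to_bf16_bits(1.0f);   // 0x3F80
  std::printf("next bf16 above  1.0: %.7f\n", bf16_bits_to_float(bf16_next_up(one)));  // 1.0078125
  uint16_t neg = float_to_bf16_bits(-1.0f);  // 0xBF80
  std::printf("next bf16 above -1.0: %.7f\n", bf16_bits_to_float(bf16_next_up(neg)));  // -0.9921875
  return 0;
}
```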

Contributor

I see, thanks for clarifying. Out of curiosity, would there be any perf advantage to this over using the original int16_t throughout (and casting it to float where needed)?

Contributor Author

Hmm maybe, I'm not sure

Member

@rmaschal your explanation here for why we use an integral type is great. Can we add that to the code for future eyes please?

Contributor Author

@cjnolet the part about relating arithmetic on int16_t to finding the next smaller/larger bfloat16 is in the description of bfloat16_next_delta(..). But I added the full reasoning behind using int16_t (OSS ScaNN expects it + this AVQ-specific conversation) as a comment in the index definition in scann.hpp

Contributor

@bkarsin bkarsin left a comment

Thanks for answering my questions. PR looks good to me.

@copy-pr-bot

copy-pr-bot Bot commented Oct 6, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@rmaschal rmaschal force-pushed the bfloat-avq branch 2 times, most recently from 89a2ec7 to 9f495b0 October 6, 2025 18:12
@rmaschal rmaschal force-pushed the bfloat-avq branch 2 times, most recently from 937bf9c to e849db7 October 6, 2025 18:35
@rmaschal rmaschal requested a review from cjnolet October 6, 2025 19:36
Contributor

@tfeher tfeher left a comment

Hey @rmaschal, the PR looks good overall. I have one suggestion to improve the documentation.

Comment thread cpp/src/neighbors/scann/detail/scann_quantize.cuh
Contributor

@tfeher tfeher left a comment

Thanks @rmaschal for the updates! The PR looks good to me.

@cjnolet cjnolet changed the base branch from branch-25.10 to branch-25.12 October 10, 2025 18:58
@cjnolet
Member

cjnolet commented Oct 21, 2025

/merge

@rapids-bot rapids-bot Bot merged commit a409b30 into rapidsai:main Oct 21, 2025
161 of 164 checks passed
robertmaynard pushed a commit to robertmaynard/cuvs that referenced this pull request Oct 28, 2025
…tion (rapidsai#1354)


Authors:
  - https://github.com/rmaschal
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Karsin (https://github.com/bkarsin)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1354

Labels

improvement (Improves an existing functionality), non-breaking (Introduces a non-breaking change)

Projects

Status: Done

4 participants