[Review] ScaNN: Add option for AVQ/Noise Shaping to bfloat16 quantization#1354
rapids-bot[bot] merged 10 commits into rapidsai:main
Conversation
```diff
   stream);

-  if (params.bf16_enabled) {
+  if (params.reordering_bf16_enabled) {
```
I'm looking over your implementation here trying to figure out the expectations from the user. What I'm gathering is that this is taking a dataset in float and quantizing to bfloat16.
Is there a benefit to doing the quantization inside the build process instead of the user quantizing up front and then passing the quantized dataset into build()? Is it because it's quantizing independently for each IVF list rather than quantizing the dataset as a whole?
It's quantizing the dataset as a whole, but the full precision dataset is used for all other steps of the build process, not the quantized dataset. So the user quantizing first and passing that in wouldn't make sense.
I do the quant in this particular loop mostly for locality reasons. We have a batch on the GPU for the other SOAR/quant steps, so we might as well do the bfloat16 there too.
For regular bfloat16 quant (basically chopping off the last 16 bits) it's not really important that we do it right here. In fact, the extra DtoH copy might be more expensive than just doing regular bfloat16 on the CPU. I haven't tried yet.
For bfloat16 w/ AVQ (introduced in this PR), we definitely want to use gpu, and I don't think it makes sense to introduce another HtoD copy of the dataset to do it elsewhere (e.g. outside of build)
For bfloat16 w/ AVQ (introduced in this PR), we definitely want to use gpu, and I don't think it makes sense to introduce another HtoD copy of the dataset to do it elsewhere (e.g. outside of build)
The question here is more about reuse than it is about whether or not you should be doing this in the build. While I'm not particularly fond of the idea of maintaining copies in device memory, I think I can accept that more. What I'm concerned about is having different implementations of the same sorts of computations, transformations, and computational functions duplicated throughout the codebase. That's why I ask the question "Can this be its own API or is it inherently coupled to SCaNN?"
Just as a heads up, if we end up finding that it's more reusable, we're going to want to pull it out into its own thing, and if that means introducing an additional copy or small perf hit as a result, that's a trade-off we are generally willing to make in order to keep the codebase clean and maintainable. However, I'm fine keeping it where it is in the meantime if it's the only place in the codebase that needs to do it. Does that make sense?
Makes sense to me. In that case, your concern applies to several things I've added and will add to ScaNN (anything related to AVQ, SOAR, etc.). I've generally added these under ScaNN for my convenience, but in principle they could all be generalized. There is no inherent coupling with "ScaNN" (which is a library, not an algorithm). The coupling is only with the query distance function (it should be dot product).
The perf issue I've seen shows up with other functions (the predict function for VPQ comes to mind). I think this could be solved by providing a single-batch device API, and a batched (host or device) API which calls the single-batch API. I can use the first in my build func since I want control of the batching (e.g. for good locality), and the batch APIs are more general/ergonomic. I feel like this is out of scope for this PR, but I think this pattern makes sense.
cjnolet
left a comment
Looking good so far. We need to make sure our public APIs are well documented, especially for users not inherently familiar w/ SCaNN.
```cpp
 * If the threshold is NAN, AVQ is not performed during bfloat16 quant
 */
float reordering_noise_shaping_threshold = NAN;
// TODO - add other scann build params
```
Please create a GitHub issue for this and reference it here. As a general rule, we don't just drop TODOs in comments without having a corresponding GitHub issue for you or someone else to follow up.
tfeher
left a comment
Hi Robert, thank you for the PR. I am looking through it; here are my first two comments.
bkarsin
left a comment
Just a few clarifying questions / nitpicks on my side. Aside from these, it looks good to me.
```cpp
 */
__device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current)
{
  uint32_t res_sign = ((int32_t)res & (1u << 31) >> 31);
```
Nitpick, but would this be simpler as just `(int32_t)res >= 0`?
```cpp
template <uint32_t BlockSize, typename IdxT>
__launch_bounds__(BlockSize) RAFT_KERNEL
  quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset,
                                        raft::device_matrix_view<int16_t, IdxT> bf16_dataset,
```
Why are we calling this bf16 if it's an int16 matrix? It seems you're casting it to bf16 or float throughout, but wouldn't it still always fall in the range of -2^15 to 2^15? Sorry if I'm missing the main bf16 benefit here.
It's convenient to use int16_t for the datatype in the index. OSS ScaNN expects an int16 matrix, so we can directly use the serialized result without any post processing. That's where bf16_dataset comes from, a view into that large int16_t matrix. With one exception, it doesn't really matter whether we use __nv_bfloat16 or int16_t. I use the former for the float <-> __nv_bfloat16 conversion functions, but I just store the bits reinterpreted as int16_t for convenience.
The only place it matters is bfloat16_next_delta. There I'm using the IEEE representation of a bfloat16 together with arithmetic operations on int16_t to generate the next bfloat16 number that is larger or smaller than the given value (as explained above, I do this by incrementing/decrementing the mantissa, which is equivalent to inc/dec the int16_t representation). I'm not aware of a bfloat16 version of std::nextafter, which is similar to what bfloat16_next_delta does.
I see, thanks for clarifying. Out of curiosity, would there be any perf advantage to this over using the original int16_t throughout (and casting it to float where needed)?
Hmm maybe, I'm not sure
@rmaschal your explanation here for why we use an integral type is great. Can we add that to the code for future eyes please?
@cjnolet the part about relating arithmetic on int16_t to finding the next smaller/larger bfloat16 is in the description of bfloat16_next_delta(..). But I added the full reasoning behind using int16_t (OSS ScaNN expects it + this AVQ-specific convo) as a comment in the index def in scann.hpp
bkarsin
left a comment
Thanks for answering my questions. PR looks good to me.
Co-authored-by: Tamas Bela Feher <[email protected]>
/merge
This PR adds support for AVQ loss/Noise shaping to the BFloat16 dataset quantization.
AVQ loss is a modified version of L2 loss which separately penalizes the components of the residual vector which are parallel and perpendicular to the original vector. Quantizing vectors with AVQ loss rather than L2 loss gives a better approximation of the inner product, and thus performs better in Maximum Inner Product Search (https://arxiv.org/abs/1908.10396).
Math:
x : original vector
x_q : quantized vector
r = x - x_q : residual vector
r_para = < r , x > x / || x ||^2 : parallel component of the residual
r_perp = r - r_para : perpendicular component of the residual
eta >= 1 : AVQ parameter
AVQ loss = eta * || r_para ||^2 + || r_perp ||^2
For a float vector x, the goal is to find a bfloat16 vector x_q which minimizes the AVQ loss for a given eta. Unlike L2, AVQ loss is not separable (e.g. ||r_para||^2 contains cross terms from the inner product), so we cannot optimize individual dimensions in parallel and expect convergence. Instead, we use coordinate descent to optimize dimensions of x_q one at a time, until convergence.
This coordinate descent happens in the new kernel "quantize_bfloat16_noise_shaped_kernel". For efficient memory accesses and compute, one warp is assigned to optimize each dataset vector. The computation of AVQ loss is algebraically separated into two pieces: those which can be computed in parallel (i.e. those only depending on local information for the assigned dimension) and those which require global information (namely those depending on < r , x >). Finally, threads in a warp serialize to compute the final cost for their dimension, update the quantized value and the value of < r , x > (if applicable), and broadcast the updated value of < r , x > for other threads. This continues in blocks of 32 dimensions, until convergence (or a maximum of 10 iterations).
I've found this strategy does a good job taking advantage of the inherently row_major structure of the dataset/index for efficient coalesced accesses, while still making good use of compute resources (hitting >90% compute throughput on an A6000).
Besides the coordinate descent kernel, this PR adds some helper functions for the above, refactors the existing bfloat16 quantization to take advantage of them, and adds configuration for the AVQ eta (the code falls back to plain bfloat16 quantization when the AVQ threshold is NaN).