|
16 | 16 |
|
17 | 17 | #include "../../detail/vpq_dataset.cuh" |
18 | 18 | #include <chrono> |
| 19 | +#include <cmath> |
19 | 20 | #include <cuvs/neighbors/common.hpp> |
| 21 | +#include <raft/linalg/transpose.cuh> |
20 | 22 | #include <raft/matrix/gather.cuh> |
21 | 23 |
|
22 | 24 | #include "scann_common.cuh" |
@@ -267,6 +269,231 @@ void unpack_codes(raft::resources const& res, |
267 | 269 | } |
268 | 270 | } |
269 | 271 |
|
| 272 | +/** |
| 273 | + * @brief compute eta for AVQ according to Theorem 3.4 in https://arxiv.org/abs/1908.10396 |
| 274 | + * |
| 275 | + * @tparam IdxT |
| 276 | + * @param dim the dataset dimension |
| 277 | + * @param sq_norm the squared norm of the vector |
| 278 | + * @param noise_shaping_threshold the threshold T in the Theorem |
| 279 | + * @return eta |
| 280 | + */ |
| 281 | +template <typename IdxT> |
| 282 | +__device__ inline float compute_avq_eta(IdxT dim, const float sq_norm, const float threshold) |
| 283 | +{ |
| 284 | + return (dim - 1) * (threshold * threshold / sq_norm) / (1 - threshold * threshold / sq_norm); |
| 285 | +} |
| 286 | + |
| 287 | +/** |
| 288 | + * @brief helper to convert a float to bfloat16 (represented as int16_t) |
| 289 | + * |
| 290 | + * @param f the float value |
| 291 | + * @return the bflaot16 value (as int16_t) |
| 292 | + */ |
| 293 | +__device__ inline int16_t float_to_bfloat16(const float& f) |
| 294 | +{ |
| 295 | + nv_bfloat16 val = __float2bfloat16(f); |
| 296 | + return reinterpret_cast<int16_t&>(val); |
| 297 | +} |
| 298 | + |
| 299 | +/** |
| 300 | + * @brief helper to convert a bfloat16 (represented as int16_t) to float |
| 301 | + * |
| 302 | + * @param bf16 the bf16 value (represented as int16_t) |
| 303 | + * @return the float value |
| 304 | + */ |
| 305 | +__device__ inline float bfloat16_to_float(int16_t& bf16) |
| 306 | +{ |
| 307 | + nv_bfloat16 nv_bf16 = reinterpret_cast<nv_bfloat16&>(bf16); |
| 308 | + return __bfloat162float(nv_bf16); |
| 309 | +} |
| 310 | + |
| 311 | +/** |
| 312 | + * @brief Select the next bfloat16 value to try during coordinate descent |
| 313 | + * |
| 314 | + * Based on the signs of the current residual and quantized value, |
| 315 | + * increment or decrement the quantized value to push residual closer to 0 |
| 316 | + * |
| 317 | + * @param res the float residual |
| 318 | + * @param current the current quantized dimension |
| 319 | + * @return the other possible quantized value |
| 320 | + */ |
| 321 | +__device__ inline int16_t bfloat16_next_delta(float& res, int16_t& current) |
| 322 | +{ |
| 323 | + uint32_t res_sign = ((int32_t)res & (1u << 31) >> 31); |
| 324 | + uint32_t curr_sign = (current & (1 << 15)) >> 15; |
| 325 | + |
| 326 | + if (res_sign == curr_sign) { return current - 1; } |
| 327 | + |
| 328 | + return current + 1; |
| 329 | +} |
| 330 | + |
| 331 | +/** |
| 332 | + * |
| 333 | + */ |
| 334 | +template <uint32_t BlockSize, typename IdxT> |
| 335 | +__launch_bounds__(BlockSize) RAFT_KERNEL |
| 336 | + quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view<const float, IdxT> dataset, |
| 337 | + raft::device_matrix_view<int16_t, IdxT> bf16_dataset, |
| 338 | + raft::device_vector_view<const float> sq_norms, |
| 339 | + float noise_shaping_threshold) |
| 340 | +{ |
| 341 | + IdxT row_idx = raft::Pow2<32>::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); |
| 342 | + |
| 343 | + if (row_idx >= dataset.extent(0)) { return; } |
| 344 | + |
| 345 | + uint32_t lane_id = raft::Pow2<32>::mod(threadIdx.x); |
| 346 | + |
| 347 | + IdxT dim = dataset.extent(1); |
| 348 | + |
| 349 | + // 1 / ||x|| |
| 350 | + float inv_norm = 1 / sqrtf(sq_norms[row_idx]); |
| 351 | + float eta = compute_avq_eta(dim, sq_norms[row_idx], noise_shaping_threshold); |
| 352 | + |
| 353 | + // < r, x > |
| 354 | + float residual_dot = 0.0; |
| 355 | + |
| 356 | + for (int i = lane_id; i < dim; i += 32) { |
| 357 | + bf16_dataset(row_idx, i) = float_to_bfloat16(dataset(row_idx, i)); |
| 358 | + |
| 359 | + float residual = dataset(row_idx, i) - bfloat16_to_float(bf16_dataset(row_idx, i)); |
| 360 | + residual_dot += dataset(row_idx, i) * residual * inv_norm; |
| 361 | + } |
| 362 | + |
| 363 | + // reduce and broadcast residual_dot across warp |
| 364 | + for (uint32_t offset = 16; offset > 0; offset >>= 1) { |
| 365 | + residual_dot += raft::shfl_xor(residual_dot, offset, 32); |
| 366 | + } |
| 367 | + |
| 368 | + constexpr uint32_t kMaxRounds = 10; |
| 369 | + |
| 370 | + bool round_changes = true; |
| 371 | + for (int round = 0; round < kMaxRounds && round_changes; round++) { |
| 372 | + round_changes = false; |
| 373 | + |
| 374 | + for (int i = lane_id; i < dim; i += 32) { |
| 375 | + // coaleseced reads of required data |
| 376 | + float original = dataset(row_idx, i); |
| 377 | + int16_t quantized = bf16_dataset(row_idx, i); |
| 378 | + |
| 379 | + float old_residual = original - bfloat16_to_float(quantized); |
| 380 | + int16_t quantized_new = bfloat16_next_delta(old_residual, quantized); |
| 381 | + |
| 382 | + float new_residual = original - bfloat16_to_float(quantized_new); |
| 383 | + float residual_dot_delta = (new_residual - old_residual) * dataset(row_idx, i) * inv_norm; |
| 384 | + |
| 385 | + float residual_norm_delta = new_residual * new_residual - old_residual * old_residual; |
| 386 | + |
| 387 | + // we want to compute the change in cost = eta || r_parallel || ^2 + || r_perpendicular|| ^2 |
| 388 | + // The change in || r_parallel ||^2 can be written (residual_dot + residual_dot_delta) ^ 2 |
| 389 | + // the change in || r_perpendicular || ^2 can be written residual_norm_delta - |
| 390 | + // parallel_norm_delta Thus cost_delta = eta * (residual_dot + residual_dot_delta) ^2 + |
| 391 | + // (residual_norm_delta - (residual_dot + residual_dot_delta)^2 Expanding and simplying, |
| 392 | + // cost_delta = a + b * resdiaul_dot, where a and b are as below. Since only residual_dot is |
| 393 | + // unknown (because updates must be made synchronously) we can compute a and b in parallel |
| 394 | + // across threads in the warp and minimize computation in the update step of the coordinate |
| 395 | + // descent |
| 396 | + float a = residual_norm_delta + (eta - 1) * residual_dot_delta * residual_dot_delta; |
| 397 | + float b = 2 * (eta - 1) * residual_dot_delta; |
| 398 | + |
| 399 | + // Dim may not be divisible by 32 |
| 400 | + // Only synchronize/shuffle for active threads |
| 401 | + int active_threads = std::min<int>(32, dim - i + lane_id); |
| 402 | + int mask = (1 << active_threads) - 1; |
| 403 | + |
| 404 | + // Update step for coordinate descent. Compute the cost_delta for |
| 405 | + // each thread, update the quantized value and residual_dot if applicable, |
| 406 | + // then broadcast the new residual dot to the warp |
| 407 | + // AVQ loss the not separable, so we must optimize each dimension separately |
| 408 | + for (int j = 0; j < active_threads; j++) { |
| 409 | + if (lane_id == j) { |
| 410 | + // change in AVQ loss |
| 411 | + float cost_delta = b * residual_dot + a; |
| 412 | + |
| 413 | + if (cost_delta < 0.0) { |
| 414 | + quantized = quantized_new; |
| 415 | + residual_dot += residual_dot_delta; |
| 416 | + round_changes = true; |
| 417 | + } |
| 418 | + } |
| 419 | + |
| 420 | + // broadcast new dot product to all lanes |
| 421 | + residual_dot = raft::shfl(residual_dot, j, active_threads, mask); |
| 422 | + } |
| 423 | + |
| 424 | + // coalesced write of possibly updated quantized values |
| 425 | + bf16_dataset(row_idx, i) = quantized; |
| 426 | + } |
| 427 | + |
| 428 | + // reduce round_changes across warp |
| 429 | + for (uint32_t offset = 16; offset > 0; offset >>= 1) { |
| 430 | + round_changes |= raft::shfl_xor(round_changes, offset, 32); |
| 431 | + } |
| 432 | + } |
| 433 | +} |
| 434 | + |
| 435 | +/** |
| 436 | + * @brief Quantized a float dataset as bfloat16, with noise shaping (AVQ) |
| 437 | + * |
| 438 | + * @tparam IdxT |
| 439 | + * @param res raft resources |
| 440 | + * @param dataset the dataset (device only) size [n_rows, dim] |
| 441 | + * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim] |
| 442 | + * @param noise_shaping_threshold the threshold for AVQ |
| 443 | + */ |
| 444 | +template <typename IdxT> |
| 445 | +void quantize_bfloat16_noise_shaped(raft::resources const& res, |
| 446 | + raft::device_matrix_view<const float, IdxT> dataset, |
| 447 | + raft::device_matrix_view<int16_t, IdxT> bf16_dataset, |
| 448 | + float noise_shaping_threshold) |
| 449 | +{ |
| 450 | + cudaStream_t stream = raft::resource::get_cuda_stream(res); |
| 451 | + |
| 452 | + IdxT n_rows = dataset.extent(0); |
| 453 | + auto norms = raft::make_device_vector<float, IdxT>(res, n_rows); |
| 454 | + |
| 455 | + // populate square norms |
| 456 | + raft::linalg::norm<raft::linalg::NormType::L2Norm, raft::Apply::ALONG_ROWS>( |
| 457 | + res, dataset, norms.view()); |
| 458 | + |
| 459 | + constexpr int64_t kBlockSize = 256; |
| 460 | + |
| 461 | + dim3 threads(kBlockSize, 1, 1); |
| 462 | + dim3 blocks(raft::div_rounding_up_safe<ix_t>(n_rows, kBlockSize / 32), 1, 1); |
| 463 | + |
| 464 | + quantize_bfloat16_noise_shaped_kernel<kBlockSize, IdxT><<<blocks, threads, 0, stream>>>( |
| 465 | + dataset, bf16_dataset, raft::make_const_mdspan(norms.view()), noise_shaping_threshold); |
| 466 | + |
| 467 | + RAFT_CUDA_TRY(cudaPeekAtLastError()); |
| 468 | +} |
| 469 | + |
| 470 | +/** |
| 471 | + * @brief Quantized a float dataset as bfloat16, with optional noise shaping (AVQ) |
| 472 | + * |
| 473 | + * @tparam IdxT |
| 474 | + * @param res raft resources |
| 475 | + * @param dataset the dataset (device only) size [n_rows, dim] |
| 476 | + * @param bf16_dataset the quantized dataset (device only) size [n_rows, dim] |
| 477 | + * @param noise_shaping_threshold the threshold for AVQ (nan when not using AVQ) |
| 478 | + */ |
| 479 | +template <typename IdxT> |
| 480 | +void quantize_bfloat16(raft::resources const& res, |
| 481 | + raft::device_matrix_view<const float, IdxT> dataset, |
| 482 | + raft::device_matrix_view<int16_t, IdxT> bf16_dataset, |
| 483 | + float noise_shaping_threshold) |
| 484 | +{ |
| 485 | + if (!std::isnan(noise_shaping_threshold)) { |
| 486 | + quantize_bfloat16_noise_shaped(res, dataset, bf16_dataset, noise_shaping_threshold); |
| 487 | + } else { |
| 488 | + raft::linalg::unaryOp( |
| 489 | + bf16_dataset.data_handle(), |
| 490 | + dataset.data_handle(), |
| 491 | + dataset.size(), |
| 492 | + [] __device__(float x) { return float_to_bfloat16(x); }, |
| 493 | + resource::get_cuda_stream(res)); |
| 494 | + } |
| 495 | +} |
| 496 | + |
270 | 497 | /** |
271 | 498 | * @brief sample dataset vectors/labels and compute their residuals for PQ training |
272 | 499 | * |
|
0 commit comments