@@ -53,6 +53,17 @@ namespace cuvs::neighbors::vamana {
5353 */
5454
5555struct index_params : cuvs::neighbors::index_params {
56+ /* *
57+ * @brief Parameters used to build quantized DiskANN index; to be generated using
58+ * deserialize_codebooks()
59+ */
60+ template <typename T = float >
61+ struct codebook_params {
62+ int pq_codebook_size;
63+ int pq_dim;
64+ std::vector<T> pq_encoding_table;
65+ std::vector<T> rotation_matrix;
66+ };
5667 /* * Maximum degree of output graph corresponds to the R parameter in the original Vamana
5768 * literature. */
5869 uint32_t graph_degree = 32 ;
@@ -72,6 +83,8 @@ struct index_params : cuvs::neighbors::index_params {
7283 uint32_t queue_size = 127 ;
7384 /* * Max batchsize of reverse edge processing (reduces memory footprint) */
7485 uint32_t reverse_batchsize = 1000000 ;
86+ /* * Codebooks and related parameters */
87+ std::optional<codebook_params<float >> codebooks = std::nullopt ;
7588};
7689
7790/* *
@@ -127,6 +140,13 @@ struct index : cuvs::neighbors::index {
127140 return *dataset_;
128141 }
129142
143+ /* * Quantized dataset [size, codes_rowlen] */
144+ [[nodiscard]] inline auto quantized_data () const noexcept
145+ -> raft::device_matrix_view<const uint8_t, int64_t, raft::row_major>
146+ {
147+ return quantized_dataset_.view ();
148+ }
149+
130150 /* * vamana graph [size, graph-degree] */
131151 [[nodiscard]] inline auto graph () const noexcept
132152 -> raft::device_matrix_view<const IdxT, int64_t, raft::row_major>
@@ -150,7 +170,8 @@ struct index : cuvs::neighbors::index {
150170 : cuvs::neighbors::index(),
151171 metric_ (metric),
152172 graph_(raft::make_device_matrix<IdxT, int64_t >(res, 0 , 0 )),
153- dataset_(new cuvs::neighbors::empty_dataset<int64_t >(0 ))
173+ dataset_(new cuvs::neighbors::empty_dataset<int64_t >(0 )),
174+ quantized_dataset_(raft::make_device_matrix<uint8_t , int64_t >(res, 0 , 0 ))
154175 {
155176 }
156177
@@ -168,6 +189,7 @@ struct index : cuvs::neighbors::index {
168189 metric_(metric),
169190 graph_(raft::make_device_matrix<IdxT, int64_t >(res, 0 , 0 )),
170191 dataset_(make_aligned_dataset(res, dataset, 16 )),
192+ quantized_dataset_(raft::make_device_matrix<uint8_t , int64_t >(res, 0 , 0 )),
171193 medoid_id_(medoid_id)
172194 {
173195 RAFT_EXPECTS (dataset.extent (0 ) == vamana_graph.extent (0 ),
@@ -212,11 +234,42 @@ struct index : cuvs::neighbors::index {
212234 graph_view_ = graph_.view ();
213235 }
214236
237+ /* *
238+ * @brief Replace the current quantized dataset with a new quantized dataset.
239+ *
240+ * We create a copy of the quantized dataset on the device. The index manages the lifetime of this
241+ * copy.
242+ *
243+ * @param[in] res
244+ * @param[in] new_quantized_dataset the new quantized dataset for the index
245+ *
246+ */
247+ void update_quantized_dataset (
248+ raft::resources const & res,
249+ raft::device_matrix_view<const uint8_t , int64_t , raft::row_major> new_quantized_dataset)
250+ {
251+ RAFT_LOG_DEBUG (" Creating device copy of Vamana quantized dataset" );
252+ if ((quantized_dataset_.extent (0 ) != new_quantized_dataset.extent (0 )) ||
253+ (quantized_dataset_.extent (1 ) != new_quantized_dataset.extent (1 ))) {
254+ // clear existing memory before allocating to prevent OOM errors on large datasets
255+ if (quantized_dataset_.size ()) {
256+ quantized_dataset_ = raft::make_device_matrix<uint8_t , int64_t >(res, 0 , 0 );
257+ }
258+ quantized_dataset_ = raft::make_device_matrix<uint8_t , int64_t >(
259+ res, new_quantized_dataset.extent (0 ), new_quantized_dataset.extent (1 ));
260+ }
261+ raft::copy (quantized_dataset_.data_handle (),
262+ new_quantized_dataset.data_handle (),
263+ new_quantized_dataset.size (),
264+ raft::resource::get_cuda_stream (res));
265+ }
266+
215267 private:
216268 cuvs::distance::DistanceType metric_;
217269 raft::device_matrix<IdxT, int64_t , raft::row_major> graph_;
218270 raft::device_matrix_view<const IdxT, int64_t , raft::row_major> graph_view_;
219271 std::unique_ptr<neighbors::dataset<int64_t >> dataset_;
272+ raft::device_matrix<uint8_t , int64_t , raft::row_major> quantized_dataset_;
220273 IdxT medoid_id_;
221274};
222275/* *
@@ -457,13 +510,15 @@ auto build(raft::resources const& res,
457510 * @param[in] file_prefix prefix of path and name of index files
458511 * @param[in] index Vamana index
459512 * @param[in] include_dataset whether or not to serialize the dataset
513+ * @param[in] sector_aligned whether output file should be aligned to disk sectors of 4096 bytes
460514 *
461515 */
462516
463517void serialize (raft::resources const & handle,
464518 const std::string& file_prefix,
465519 const cuvs::neighbors::vamana::index<float , uint32_t >& index,
466- bool include_dataset = true );
520+ bool include_dataset = true ,
521+ bool sector_aligned = false );
467522
468523/* *
469524 * Save the index to file.
@@ -486,12 +541,14 @@ void serialize(raft::resources const& handle,
486541 * @param[in] file_prefix prefix of path and name of index files
487542 * @param[in] index Vamana index
488543 * @param[in] include_dataset whether or not to serialize the dataset
544+ * @param[in] sector_aligned whether output file should be aligned to disk sectors of 4096 bytes
489545 *
490546 */
491547void serialize (raft::resources const & handle,
492548 const std::string& file_prefix,
493549 const cuvs::neighbors::vamana::index<int8_t , uint32_t >& index,
494- bool include_dataset = true );
550+ bool include_dataset = true ,
551+ bool sector_aligned = false );
495552
496553/* *
497554 * Save the index to file.
@@ -514,12 +571,48 @@ void serialize(raft::resources const& handle,
514571 * @param[in] file_prefix prefix of path and name of index files
515572 * @param[in] index Vamana index
516573 * @param[in] include_dataset whether or not to serialize the dataset
574+ * @param[in] sector_aligned whether output file should be aligned to disk sectors of 4096 bytes
517575 *
518576 */
519577void serialize (raft::resources const & handle,
520578 const std::string& file_prefix,
521579 const cuvs::neighbors::vamana::index<uint8_t , uint32_t >& index,
522- bool include_dataset = true );
580+ bool include_dataset = true ,
581+ bool sector_aligned = false );
582+
583+ /* *
584+ * @}
585+ */
586+
587+ /* *
588+ * @defgroup vamana_cpp_codebook Vamana codebook functions
589+ * @{
590+ */
591+
592+ /* *
593+ * @brief Construct codebook parameters from input codebook files
594+ *
595+ * Expects pq pivots file at
596+ * "${codebook_prefix}_pq_pivots.bin" and rotation matrix file at
597+ * "${codebook_prefix}_pq_pivots.bin_rotation_matrix.bin".
598+ *
599+ * @code{.cpp}
600+ * #include <cuvs/neighbors/vamana.hpp>
601+ *
602+ * // create a string with a filepath
603+ * std::string codebook_prefix("/path/to/index/prefix");
604+ * // define dimension of vectors in dataset
605+ * int dim = 64;
606+ * // construct codebook parameters from input codebook files
607+ * auto codebooks = cuvs::neighbors::vamana::deserialize_codebooks(codebook_prefix, dim);
608+ * @endcode
609+ *
610+ * @param[in] codebook_prefix path prefix to pq pivots and rotation matrix files
611+ * @param[in] dim dimension of vectors in dataset
612+ *
613+ */
614+ auto deserialize_codebooks (const std::string& codebook_prefix, const int dim)
615+ -> index_params::codebook_params<float>;
523616
524617/* *
525618 * @}
0 commit comments