Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/include/cuvs/neighbors/vamana.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ cuvsError_t cuvsVamanaIndexGetDims(cuvsVamanaIndex_t index, int* dim);
*
* Build the index from the dataset for efficient DiskANN search.
*
* The build utilities the Vamana insertion-based algorithm to create the graph. The algorithm
* The build uses the Vamana insertion-based algorithm to create the graph. The algorithm
* starts with an empty graph and iteratively inserts batches of nodes. Each batch involves
* performing a greedy search for each vector to be inserted, and inserting it with edges to
* all nodes traversed during the search. Reverse edges are also inserted and robustPrune is applied
Expand Down
1 change: 1 addition & 0 deletions rust/cuvs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod error;
pub mod ivf_flat;
pub mod ivf_pq;
mod resources;
pub mod vamana;

pub use dlpack::ManagedTensor;
pub use error::{Error, Result};
Expand Down
118 changes: 118 additions & 0 deletions rust/cuvs/src/vamana/index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

use std::ffi::{CStr, CString};
use std::io::{stderr, Write};

use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
use crate::vamana::IndexParams;

/// Vamana ANN Index
///
/// Thin owning wrapper around a `cuvsVamanaIndex_t` handle. The handle is obtained
/// through [`Index::new`] (or [`Index::build`], which calls it) and is passed to
/// `cuvsVamanaIndexDestroy` when the value is dropped.
#[derive(Debug)]
pub struct Index(ffi::cuvsVamanaIndex_t);

impl Index {
    /// Builds a Vamana index over `dataset` for efficient DiskANN search.
    ///
    /// Construction follows the insertion-based Vamana algorithm: beginning with an empty
    /// graph, nodes are inserted in batches. For every vector in a batch a greedy search is
    /// performed and the vector is linked to all nodes traversed by that search. Reverse
    /// edges are inserted as well, and robustPrune is applied to improve graph quality.
    /// The degree of the final graph is controlled through `params`.
    ///
    /// # Arguments
    ///
    /// * `res` - Resources to use
    /// * `params` - Parameters for building the index
    /// * `dataset` - A row-major matrix on either the host or device to index
    ///
    /// # Errors
    ///
    /// Returns an error when the empty index cannot be created or when
    /// `cuvsVamanaBuild` reports a failure.
    pub fn build<T: Into<ManagedTensor>>(
        res: &Resources,
        params: &IndexParams,
        dataset: T,
    ) -> Result<Index> {
        let tensor: ManagedTensor = dataset.into();
        let built = Index::new()?;
        let status = unsafe { ffi::cuvsVamanaBuild(res.0, params.0, tensor.as_ptr(), built.0) };
        check_cuvs(status)?;
        Ok(built)
    }

    /// Creates a new empty index
    ///
    /// # Errors
    ///
    /// Returns an error when `cuvsVamanaIndexCreate` reports a failure.
    pub fn new() -> Result<Index> {
        let mut handle = std::mem::MaybeUninit::<ffi::cuvsVamanaIndex_t>::uninit();
        unsafe {
            // The handle is only read back after the create call has been checked.
            check_cuvs(ffi::cuvsVamanaIndexCreate(handle.as_mut_ptr()))?;
            Ok(Index(handle.assume_init()))
        }
    }

    /// Save Vamana index to file
    ///
    /// The output matches the file format used by the DiskANN open-source repository,
    /// allowing cross-compatibility: the serialized index is meant to be loaded by
    /// DiskANN for graph search.
    ///
    /// This method takes `self` by value, so the index is dropped once the call returns.
    ///
    /// # Arguments
    ///
    /// * `res` - Resources to use
    /// * `filename` - The file prefix for where the index is saved
    /// * `include_dataset` - whether to include the dataset in the serialized index
    ///
    /// # Errors
    ///
    /// Returns an error when `cuvsVamanaSerialize` reports a failure.
    ///
    /// # Panics
    ///
    /// Panics if `filename` contains an interior NUL byte, because it cannot be
    /// converted to a C string.
    pub fn serialize(self, res: &Resources, filename: &str, include_dataset: bool) -> Result<()> {
        let path = CString::new(filename).unwrap();
        let status =
            unsafe { ffi::cuvsVamanaSerialize(res.0, path.as_ptr(), self.0, include_dataset) };
        check_cuvs(status)
    }
}

impl Drop for Index {
    /// Releases the underlying index handle via `cuvsVamanaIndexDestroy`.
    ///
    /// `drop` has no way to return an error, so a failure reported by the destroy call
    /// is written to stderr on a best-effort basis.
    fn drop(&mut self) {
        if let Err(e) = check_cuvs(unsafe { ffi::cuvsVamanaIndexDestroy(self.0) }) {
            // A failed stderr write is deliberately ignored: panicking inside `drop`
            // while the thread is already unwinding would abort the whole process.
            let _ = write!(stderr(), "failed to call cuvsVamanaIndexDestroy {:?}", e);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Smoke test: building an index over a random dataset copied to the device succeeds.
    #[test]
    fn test_vamana() {
        let build_params = IndexParams::new().unwrap();

        let res = Resources::new().unwrap();

        // Create a new random dataset to index
        let n_datapoints = 1024;
        let n_features = 16;
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

        let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();

        // Build the vamana index. Only success of the build is checked, so the result is
        // bound to an underscore-prefixed name to keep it alive without an unused warning.
        let _index = Index::build(&res, &build_params, dataset_device)
            .expect("failed to create vamana index");
    }
}
136 changes: 136 additions & 0 deletions rust/cuvs/src/vamana/index_params.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

use crate::distance_type::DistanceType;
use crate::error::{check_cuvs, Result};
use std::fmt;
use std::io::{stderr, Write};

/// Parameters controlling how a Vamana index is built.
///
/// Wraps a `cuvsVamanaIndexParams_t` obtained from `cuvsVamanaIndexParamsCreate`; the
/// wrapped value is passed to `cuvsVamanaIndexParamsDestroy` when this struct is dropped.
/// The `set_*` methods consume `self` and return it again so calls can be chained.
pub struct IndexParams(pub ffi::cuvsVamanaIndexParams_t);

impl IndexParams {
    /// Returns a new IndexParams
    ///
    /// # Errors
    ///
    /// Returns an error when `cuvsVamanaIndexParamsCreate` reports a failure.
    pub fn new() -> Result<IndexParams> {
        let mut raw = std::mem::MaybeUninit::<ffi::cuvsVamanaIndexParams_t>::uninit();
        unsafe {
            // The pointer is only read back after the create call has been checked.
            check_cuvs(ffi::cuvsVamanaIndexParamsCreate(raw.as_mut_ptr()))?;
            Ok(IndexParams(raw.assume_init()))
        }
    }

    /// Runs `apply` on the wrapped parameter pointer, then hands `self` back for chaining.
    fn update(self, apply: impl FnOnce(ffi::cuvsVamanaIndexParams_t)) -> IndexParams {
        apply(self.0);
        self
    }

    /// DistanceType to use for building the index
    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
        self.update(|p| unsafe { (*p).metric = metric })
    }

    /// Maximum degree of the output graph; corresponds to the R parameter in the original
    /// Vamana literature.
    pub fn set_graph_degree(self, graph_degree: u32) -> IndexParams {
        self.update(|p| unsafe { (*p).graph_degree = graph_degree })
    }

    /// Maximum number of visited nodes per search; corresponds to the L parameter in the
    /// Vamana literature.
    pub fn set_visited_size(self, visited_size: u32) -> IndexParams {
        self.update(|p| unsafe { (*p).visited_size = visited_size })
    }

    /// Number of Vamana vector insertion iterations (each iteration inserts all vectors).
    pub fn set_vamana_iters(self, vamana_iters: f32) -> IndexParams {
        self.update(|p| unsafe { (*p).vamana_iters = vamana_iters })
    }

    /// Alpha for pruning parameter
    pub fn set_alpha(self, alpha: f32) -> IndexParams {
        self.update(|p| unsafe { (*p).alpha = alpha })
    }

    /// Maximum fraction of dataset inserted per batch.
    /// Larger max batch decreases graph quality, but improves speed
    pub fn set_max_fraction(self, max_fraction: f32) -> IndexParams {
        self.update(|p| unsafe { (*p).max_fraction = max_fraction })
    }

    /// Base of growth rate of batch sizes
    pub fn set_batch_base(self, batch_base: f32) -> IndexParams {
        self.update(|p| unsafe { (*p).batch_base = batch_base })
    }

    /// Size of candidate queue structure - should be (2^x)-1
    pub fn set_queue_size(self, queue_size: u32) -> IndexParams {
        self.update(|p| unsafe { (*p).queue_size = queue_size })
    }

    /// Max batchsize of reverse edge processing (reduces memory footprint)
    pub fn set_reverse_batchsize(self, reverse_batchsize: u32) -> IndexParams {
        self.update(|p| unsafe { (*p).reverse_batchsize = reverse_batchsize })
    }
}

impl fmt::Debug for IndexParams {
    /// Formats the pointed-to parameter struct instead of the raw pointer: the derived
    /// output would only show the pointer address, which isn't that useful.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner = unsafe { *self.0 };
        write!(f, "IndexParams({:?})", inner)
    }
}

impl Drop for IndexParams {
    /// Releases the underlying parameter object via `cuvsVamanaIndexParamsDestroy`.
    ///
    /// `drop` has no way to return an error, so a failure reported by the destroy call
    /// is written to stderr on a best-effort basis.
    fn drop(&mut self) {
        if let Err(e) = check_cuvs(unsafe { ffi::cuvsVamanaIndexParamsDestroy(self.0) }) {
            // A failed stderr write is deliberately ignored: panicking inside `drop`
            // while the thread is already unwinding would abort the whole process.
            let _ = write!(
                stderr(),
                "failed to call cuvsVamanaIndexParamsDestroy {:?}",
                e
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Values written through the chained setters must be readable from the wrapped struct.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new().unwrap().set_alpha(1.0).set_visited_size(128);

        // Read both fields back in one place, then assert outside the unsafe block.
        let (alpha, visited_size) = unsafe { ((*params.0).alpha, (*params.0).visited_size) };

        assert_eq!(alpha, 1.0);
        assert_eq!(visited_size, 128);
    }
}
11 changes: 11 additions & 0 deletions rust/cuvs/src/vamana/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
//! Vamana: graph index construction whose serialized output is compatible with DiskANN search.

mod index;
mod index_params;

pub use index::Index;
pub use index_params::IndexParams;
Loading