Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/tests/neighbors/ann_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ class AnnBruteForceTest : public ::testing::TestWithParam<AnnBruteForceInputs<Id
stream_,
true));

brute_force::serialize(handle_, std::string{"brute_force_index"}, idx, true);
tmp_index_file index_file;
brute_force::serialize(handle_, index_file.filename, idx, true);
auto index_loaded = brute_force::index<DataT, T>(handle_);
brute_force::deserialize(handle_, std::string{"brute_force_index"}, &index_loaded);
brute_force::deserialize(handle_, index_file.filename, &index_loaded);

brute_force::search(handle_,
index_loaded,
Expand Down
5 changes: 3 additions & 2 deletions cpp/tests/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);

tmp_index_file index_file;
{
std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_, index_params.metric);
Expand All @@ -422,11 +423,11 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
index = cagra::build(handle_, index_params, database_view);
};

cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset);
cagra::serialize(handle_, index_file.filename, index, ps.include_serialized_dataset);
}

cagra::index<DataT, IdxT> index(handle_);
cagra::deserialize(handle_, "cagra_index", &index);
cagra::deserialize(handle_, index_file.filename, &index);

if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
indices_ivfflat_dev.data(), ps.num_queries, ps.k);
auto dists_out_view = raft::make_device_matrix_view<T, IdxT>(
distances_ivfflat_dev.data(), ps.num_queries, ps.k);
const std::string filename = "ivf_flat_index";
cuvs::neighbors::ivf_flat::serialize(handle_, filename, index_2);
tmp_index_file index_file;
cuvs::neighbors::ivf_flat::serialize(handle_, index_file.filename, index_2);
cuvs::neighbors::ivf_flat::index<DataT, IdxT> index_loaded(handle_);
cuvs::neighbors::ivf_flat::deserialize(handle_, filename, &index_loaded);
cuvs::neighbors::ivf_flat::deserialize(handle_, index_file.filename, &index_loaded);
ASSERT_EQ(index_2.size(), index_loaded.size());

cuvs::neighbors::ivf_flat::search(handle_,
Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {

auto build_serialize()
{
std::string filename = "ivf_pq_index";
cuvs::neighbors::ivf_pq::serialize(handle_, filename, build_only());
tmp_index_file index_file;
cuvs::neighbors::ivf_pq::serialize(handle_, index_file.filename, build_only());
cuvs::neighbors::ivf_pq::index<IdxT> index(handle_);
cuvs::neighbors::ivf_pq::deserialize(handle_, filename, &index);
cuvs::neighbors::ivf_pq::deserialize(handle_, index_file.filename, &index);
return index;
}

Expand Down
26 changes: 26 additions & 0 deletions cpp/tests/neighbors/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include "naive_knn.cuh"

#include "../test_utils.cuh"
#include <atomic>
#include <cstdio>
#include <filesystem>
#include <gtest/gtest.h>
#include <iostream>
#include <limits>
Expand Down Expand Up @@ -346,4 +349,27 @@ auto eval_distances(raft::resources const& handle,
}
return testing::AssertionSuccess();
}

/**
* A helper class to create a temporary file for a cuVS index object in the system's temp directory.
* The file will be automatically deleted when the object is destroyed.
*/
struct tmp_index_file {
// Ideally, we should use std::tmpfile() or another system-provided API to create a temporary
// file. However, our API requires a file name, so we cannot use the file descriptors. There's no
// recommended way to generate a robust unique temp filenames, so we use a combination of a
// counter, process id, and random number.
std::string filename = (std::filesystem::temp_directory_path() /
("cuvs_" + std::to_string(getpid()) + "_" + std::to_string(counter++) +
"_" + std::to_string(std::rand())))
.string();
~tmp_index_file()
{
if (std::filesystem::exists(filename)) { std::filesystem::remove(filename); }
}

private:
static inline std::atomic<uint64_t> counter = 0;
};

} // namespace cuvs::neighbors
3 changes: 2 additions & 1 deletion cpp/tests/neighbors/ann_vamana.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ class AnnVamanaTest : public ::testing::TestWithParam<AnnVamanaInputs> {

CheckGraph<DataT, IdxT>(&index, ps, stream_);

vamana::serialize(handle_, "vamana_index", index);
tmp_index_file index_file;
vamana::serialize(handle_, index_file.filename, index);

// Test recall by searching with CAGRA search
if (ps.graph_degree < 256) { // CAGRA search result buffer cannot support larger graph degree
Expand Down
30 changes: 18 additions & 12 deletions cpp/tests/neighbors/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,14 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, int64_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset);
cuvs::neighbors::ivf_flat::extend(clique_, index, index_dataset, std::nullopt);
cuvs::neighbors::ivf_flat::serialize(clique_, index, "mg_ivf_flat_index");
cuvs::neighbors::ivf_flat::serialize(clique_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::ivf_flat::deserialize<DataT, int64_t>(clique_, "mg_ivf_flat_index");
cuvs::neighbors::ivf_flat::deserialize<DataT, int64_t>(clique_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -187,13 +188,14 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, int64_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset);
cuvs::neighbors::ivf_pq::extend(clique_, index, index_dataset, std::nullopt);
cuvs::neighbors::ivf_pq::serialize(clique_, index, "mg_ivf_pq_index");
cuvs::neighbors::ivf_pq::serialize(clique_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::ivf_pq::deserialize<DataT, int64_t>(clique_, "mg_ivf_pq_index");
cuvs::neighbors::ivf_pq::deserialize<DataT, int64_t>(clique_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -243,12 +245,13 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, uint32_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset);
cuvs::neighbors::cagra::serialize(clique_, index, "mg_cagra_index");
cuvs::neighbors::cagra::serialize(clique_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::cagra::deserialize<DataT, uint32_t>(clique_, "mg_cagra_index");
cuvs::neighbors::cagra::deserialize<DataT, uint32_t>(clique_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -286,11 +289,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.n_probes = ps.nprobe;
search_params.search_mode = LOAD_BALANCER;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::ivf_flat::build(clique_, index_params, index_dataset);
ivf_flat::serialize(clique_, "local_ivf_flat_index", index);
ivf_flat::serialize(clique_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -301,7 +305,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::ivf_flat::distribute<DataT, int64_t>(clique_, "local_ivf_flat_index");
cuvs::neighbors::ivf_flat::distribute<DataT, int64_t>(clique_, index_file.filename);
search_params.merge_mode = TREE_MERGE;

search_params.n_rows_per_batch = n_rows_per_search_batch;
Expand Down Expand Up @@ -335,11 +339,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.n_probes = ps.nprobe;
search_params.search_mode = LOAD_BALANCER;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::ivf_pq::build(clique_, index_params, index_dataset);
ivf_pq::serialize(clique_, "local_ivf_pq_index", index);
ivf_pq::serialize(clique_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -350,7 +355,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::ivf_pq::distribute<DataT, int64_t>(clique_, "local_ivf_pq_index");
cuvs::neighbors::ivf_pq::distribute<DataT, int64_t>(clique_, index_file.filename);
search_params.merge_mode = TREE_MERGE;

search_params.n_rows_per_batch = n_rows_per_search_batch;
Expand Down Expand Up @@ -379,11 +384,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {

mg_search_params<cagra::search_params> search_params;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::cagra::build(clique_, index_params, index_dataset);
cuvs::neighbors::cagra::serialize(clique_, "local_cagra_index", index);
cuvs::neighbors::cagra::serialize(clique_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -394,7 +400,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::cagra::distribute<DataT, uint32_t>(clique_, "local_cagra_index");
cuvs::neighbors::cagra::distribute<DataT, uint32_t>(clique_, index_file.filename);

search_params.merge_mode = TREE_MERGE;

Expand Down