Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/tests/neighbors/ann_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ class AnnBruteForceTest : public ::testing::TestWithParam<AnnBruteForceInputs<Id
stream_,
true));

brute_force::serialize(handle_, std::string{"brute_force_index"}, idx, true);
tmp_index_file index_file;
brute_force::serialize(handle_, index_file.filename, idx, true);
auto index_loaded = brute_force::index<DataT, T>(handle_);
brute_force::deserialize(handle_, std::string{"brute_force_index"}, &index_loaded);
brute_force::deserialize(handle_, index_file.filename, &index_loaded);

brute_force::search(handle_,
index_loaded,
Expand Down
5 changes: 3 additions & 2 deletions cpp/tests/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);

tmp_index_file index_file;
{
std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_, index_params.metric);
Expand All @@ -422,11 +423,11 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
index = cagra::build(handle_, index_params, database_view);
};

cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset);
cagra::serialize(handle_, index_file.filename, index, ps.include_serialized_dataset);
}

cagra::index<DataT, IdxT> index(handle_);
cagra::deserialize(handle_, "cagra_index", &index);
cagra::deserialize(handle_, index_file.filename, &index);

if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
indices_ivfflat_dev.data(), ps.num_queries, ps.k);
auto dists_out_view = raft::make_device_matrix_view<T, IdxT>(
distances_ivfflat_dev.data(), ps.num_queries, ps.k);
const std::string filename = "ivf_flat_index";
cuvs::neighbors::ivf_flat::serialize(handle_, filename, index_2);
tmp_index_file index_file;
cuvs::neighbors::ivf_flat::serialize(handle_, index_file.filename, index_2);
cuvs::neighbors::ivf_flat::index<DataT, IdxT> index_loaded(handle_);
cuvs::neighbors::ivf_flat::deserialize(handle_, filename, &index_loaded);
cuvs::neighbors::ivf_flat::deserialize(handle_, index_file.filename, &index_loaded);
ASSERT_EQ(index_2.size(), index_loaded.size());

cuvs::neighbors::ivf_flat::search(handle_,
Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {

auto build_serialize()
{
std::string filename = "ivf_pq_index";
cuvs::neighbors::ivf_pq::serialize(handle_, filename, build_only());
tmp_index_file index_file;
cuvs::neighbors::ivf_pq::serialize(handle_, index_file.filename, build_only());
cuvs::neighbors::ivf_pq::index<IdxT> index(handle_);
cuvs::neighbors::ivf_pq::deserialize(handle_, filename, &index);
cuvs::neighbors::ivf_pq::deserialize(handle_, index_file.filename, &index);
return index;
}

Expand Down
26 changes: 26 additions & 0 deletions cpp/tests/neighbors/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include "naive_knn.cuh"

#include "../test_utils.cuh"
#include <atomic>
#include <cstdio>
#include <filesystem>
#include <gtest/gtest.h>
#include <iostream>
#include <limits>
Expand Down Expand Up @@ -346,4 +349,27 @@ auto eval_distances(raft::resources const& handle,
}
return testing::AssertionSuccess();
}

/**
* A helper class to create a temporary file for a cuVS index object in the system's temp directory.
* The file will be automatically deleted when the object is destroyed.
*/
struct tmp_index_file {
// Ideally, we should use std::tmpfile() or another system-provided API to create a temporary
// file. However, our API requires a file name, so we cannot use the file descriptors. There's no
// recommended way to generate a robust unique temp filenames, so we use a combination of a
// counter, process id, and random number.
std::string filename = (std::filesystem::temp_directory_path() /
("cuvs_" + std::to_string(getpid()) + "_" + std::to_string(counter++) +
"_" + std::to_string(std::rand())))
.string();
~tmp_index_file()
{
if (std::filesystem::exists(filename)) { std::filesystem::remove(filename); }
}

private:
static inline std::atomic<uint64_t> counter = 0;
};

} // namespace cuvs::neighbors
3 changes: 2 additions & 1 deletion cpp/tests/neighbors/ann_vamana.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ class AnnVamanaTest : public ::testing::TestWithParam<AnnVamanaInputs> {

CheckGraph<DataT, IdxT>(&index, ps, stream_);

vamana::serialize(handle_, "vamana_index", index);
tmp_index_file index_file;
vamana::serialize(handle_, index_file.filename, index);

// Test recall by searching with CAGRA search
if (ps.graph_degree < 256) { // CAGRA search result buffer cannot support larger graph degree
Expand Down
30 changes: 18 additions & 12 deletions cpp/tests/neighbors/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,14 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, int64_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset);
cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt);
cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_flat_index");
cuvs::neighbors::mg::serialize(handle_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::mg::deserialize_flat<DataT, int64_t>(handle_, "mg_ivf_flat_index");
cuvs::neighbors::mg::deserialize_flat<DataT, int64_t>(handle_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -176,13 +177,14 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, int64_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset);
cuvs::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt);
cuvs::neighbors::mg::serialize(handle_, index, "mg_ivf_pq_index");
cuvs::neighbors::mg::serialize(handle_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::mg::deserialize_pq<DataT, int64_t>(handle_, "mg_ivf_pq_index");
cuvs::neighbors::mg::deserialize_pq<DataT, int64_t>(handle_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -230,12 +232,13 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
auto distances = raft::make_host_matrix_view<float, uint32_t, row_major>(
distances_snmg_ann.data(), ps.num_queries, ps.k);

tmp_index_file index_file;
{
auto index = cuvs::neighbors::mg::build(handle_, index_params, index_dataset);
cuvs::neighbors::mg::serialize(handle_, index, "mg_cagra_index");
cuvs::neighbors::mg::serialize(handle_, index, index_file.filename);
}
auto new_index =
cuvs::neighbors::mg::deserialize_cagra<DataT, uint32_t>(handle_, "mg_cagra_index");
cuvs::neighbors::mg::deserialize_cagra<DataT, uint32_t>(handle_, index_file.filename);

if (ps.m_mode == m_mode_t::MERGE_ON_ROOT_RANK)
search_params.merge_mode = MERGE_ON_ROOT_RANK;
Expand Down Expand Up @@ -271,11 +274,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.n_probes = ps.nprobe;
search_params.search_mode = LOAD_BALANCER;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::ivf_flat::build(handle_, index_params, index_dataset);
ivf_flat::serialize(handle_, "local_ivf_flat_index", index);
ivf_flat::serialize(handle_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -286,7 +290,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::mg::distribute_flat<DataT, int64_t>(handle_, "local_ivf_flat_index");
cuvs::neighbors::mg::distribute_flat<DataT, int64_t>(handle_, index_file.filename);
search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(handle_,
distributed_index,
Expand Down Expand Up @@ -323,11 +327,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.n_probes = ps.nprobe;
search_params.search_mode = LOAD_BALANCER;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::ivf_pq::build(handle_, index_params, index_dataset);
ivf_pq::serialize(handle_, "local_ivf_pq_index", index);
ivf_pq::serialize(handle_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -338,7 +343,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::mg::distribute_pq<DataT, int64_t>(handle_, "local_ivf_pq_index");
cuvs::neighbors::mg::distribute_pq<DataT, int64_t>(handle_, index_file.filename);
search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(handle_,
distributed_index,
Expand Down Expand Up @@ -370,11 +375,12 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {

mg::search_params<cagra::search_params> search_params;

tmp_index_file index_file;
{
auto index_dataset = raft::make_device_matrix_view<const DataT, int64_t>(
d_index_dataset.data(), ps.num_db_vecs, ps.dim);
auto index = cuvs::neighbors::cagra::build(handle_, index_params, index_dataset);
cuvs::neighbors::cagra::serialize(handle_, "local_cagra_index", index);
cuvs::neighbors::cagra::serialize(handle_, index_file.filename, index);
}

auto queries = raft::make_host_matrix_view<const DataT, int64_t, row_major>(
Expand All @@ -385,7 +391,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances_snmg_ann.data(), ps.num_queries, ps.k);

auto distributed_index =
cuvs::neighbors::mg::distribute_cagra<DataT, uint32_t>(handle_, "local_cagra_index");
cuvs::neighbors::mg::distribute_cagra<DataT, uint32_t>(handle_, index_file.filename);

search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(handle_,
Expand Down
Loading