Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,28 @@ if(BUILD_TESTS)
)

ConfigureTest(
NAME
NEIGHBORS_ANN_CAGRA_TEST
PATH
neighbors/ann_cagra/bug_extreme_inputs_oob.cu
neighbors/ann_cagra/bug_multi_cta_crash.cu
neighbors/ann_cagra/test_float_uint32_t.cu
neighbors/ann_cagra/test_half_uint32_t.cu
neighbors/ann_cagra/test_int8_t_uint32_t.cu
neighbors/ann_cagra/test_uint8_t_uint32_t.cu
GPUS
1
PERCENT
100
NAME NEIGHBORS_ANN_CAGRA_TEST_BUGS PATH neighbors/ann_cagra/bug_extreme_inputs_oob.cu
neighbors/ann_cagra/bug_multi_cta_crash.cu GPUS 1 PERCENT 100
)

ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST PATH neighbors/ann_cagra/test_float_uint32_t.cu GPUS
1 PERCENT 100
)

ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST PATH neighbors/ann_cagra/test_half_uint32_t.cu GPUS 1
PERCENT 100
)

ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST PATH neighbors/ann_cagra/test_int8_t_uint32_t.cu GPUS
1 PERCENT 100
)

ConfigureTest(
NAME NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST PATH neighbors/ann_cagra/test_uint8_t_uint32_t.cu
GPUS 1 PERCENT 100
)

ConfigureTest(
Expand Down
53 changes: 31 additions & 22 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -876,14 +876,15 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
inline std::vector<AnnCagraInputs> generate_inputs()
{
// TODO(tfeher): test MULTI_CTA kernel with search_width > 1 to allow multiple CTA per queries
// Varying dim, k, graph_build_algo, search_algo, max_queries
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love the extra annotations on what are we varying. Thanks!

std::vector<AnnCagraInputs> inputs = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
{1, 8, 17},
{1, 17},
{1, 16}, // k
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dim and build algo combinations are tested below, therefor we focus on dim and search algo and max_query parameter value here.

{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{0, 1, 10, 100}, // query size
{0, 1, 100}, // query size
{0},
{256},
{1},
Expand All @@ -892,11 +893,12 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{true},
{0.995});

// Varying dim, graph_build_algo
auto inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
{1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim
{16}, // k
{1, 3, 7, 17, 128, 192, 512, 1024}, // dim
{16}, // k
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::AUTO},
{10},
Expand All @@ -908,6 +910,8 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Varying team_size, graph_build_algo
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
Expand All @@ -925,6 +929,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Varying graph_build_algo, itopk_size
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
Expand All @@ -942,6 +947,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Varying n_rows, host_dataset
inputs2 =
raft::util::itertools::product<AnnCagraInputs>({100},
{10000, 20000},
Expand All @@ -959,7 +965,8 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.985});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// a few PQ configurations
// A few PQ configurations.
// Varying dim, vq_n_centers
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{10000},
Expand Down Expand Up @@ -987,11 +994,12 @@ inline std::vector<AnnCagraInputs> generate_inputs()
}
}

// refinement options
// Refinement options
// Varying host_dataset, ivf_pq_search_refine_ratio
inputs2 =
raft::util::itertools::product<AnnCagraInputs>({100},
{5000},
{32, 64},
{64},
Comment thread
bdice marked this conversation as resolved.
Outdated
{16},
{graph_build_algo::IVF_PQ},
{search_algo::AUTO},
Expand All @@ -1006,21 +1014,22 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1.0f, 2.0f, 3.0f});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
{1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim
{10},
{graph_build_algo::IVF_PQ},
{search_algo::AUTO},
{10},
{0}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
// Varying dim, adding non_owning_memory_buffer_flag
inputs2 =
raft::util::itertools::product<AnnCagraInputs>({100},
{1000},
{1, 5, 8, 64, 137, 256, 619, 1024}, // dim
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we test a few edge cases where the dimensionality is 2ⁿ or 2ⁿ ± 1. These are important to check whether we have any problems with padding the data during build and search.

Copy link
Copy Markdown
Contributor Author

@bdice bdice Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured that was the case. There were a few places that varied dim over a wide range of values. I kept different dim values in each test so that the full range would be covered. Perhaps we can cover the full range in only one set of inputs instead, and use a small range for the other cases?

{10},
{graph_build_algo::IVF_PQ},
{search_algo::AUTO},
{10},
{0}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
for (auto input : inputs2) {
input.non_owning_memory_buffer_flag = true;
inputs.push_back(input);
Expand Down