Skip to content

Commit dc99bfb

Browse files
committed
fix gtests
1 parent ee28540 commit dc99bfb

3 files changed

Lines changed: 132 additions & 1 deletion

File tree

cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -859,6 +859,9 @@ inline void search(raft::resources const& handle,
859859
static_cast<uint64_t>(index.size()));
860860
RAFT_EXPECTS(params.n_probes > 0,
861861
"n_probes (number of clusters to probe in the search) must be positive.");
862+
RAFT_EXPECTS(index.codes_layout() == list_layout::INTERLEAVED,
863+
"IVF-PQ search requires INTERLEAVED codes layout. FLAT layout is not supported for "
864+
"GPU search.");
862865

863866
switch (utils::check_pointer_residency(queries, neighbors, distances)) {
864867
case utils::pointer_residency::device_only:

cpp/tests/neighbors/ann_ivf_pq.cuh

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,62 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
565565
}
566566
}
567567

568+
/**
569+
* Test that FLAT layout produces the same PQ codes as INTERLEAVED layout.
570+
* Builds both layouts and compares the unpacked codes.
571+
*/
572+
void check_flat_layout_codes()
573+
{
574+
auto ipams_interleaved = ps.index_params;
575+
ipams_interleaved.add_data_on_build = true;
576+
ipams_interleaved.codes_layout = list_layout::INTERLEAVED;
577+
578+
auto database_view =
579+
raft::make_device_matrix_view<const DataT, int64_t>(database.data(), ps.num_db_vecs, ps.dim);
580+
auto index_interleaved =
581+
cuvs::neighbors::ivf_pq::build(handle_, ipams_interleaved, database_view);
582+
583+
auto ipams_flat = ps.index_params;
584+
ipams_flat.add_data_on_build = true;
585+
ipams_flat.codes_layout = list_layout::FLAT;
586+
587+
auto index_flat = cuvs::neighbors::ivf_pq::build(handle_, ipams_flat, database_view);
588+
589+
ASSERT_EQ(index_interleaved.codes_layout(), list_layout::INTERLEAVED);
590+
ASSERT_EQ(index_flat.codes_layout(), list_layout::FLAT);
591+
592+
ASSERT_TRUE(cuvs::devArrMatch(index_interleaved.list_sizes().data_handle(),
593+
index_flat.list_sizes().data_handle(),
594+
index_interleaved.n_lists(),
595+
cuvs::Compare<uint32_t>{}));
596+
597+
uint32_t bytes_per_vector =
598+
raft::div_rounding_up_safe(index_flat.pq_dim() * index_flat.pq_bits(), 8u);
599+
600+
for (uint32_t label = 0; label < index_interleaved.n_lists(); label++) {
601+
auto& list_interleaved = index_interleaved.lists()[label];
602+
auto& list_flat = index_flat.lists()[label];
603+
604+
uint32_t n_rows = list_interleaved->size.load();
605+
if (n_rows == 0) { continue; }
606+
607+
rmm::device_uvector<uint8_t> interleaved_codes(n_rows * bytes_per_vector, stream_);
608+
helpers::codepacker::unpack_contiguous(handle_,
609+
list_interleaved->data.view(),
610+
index_interleaved.pq_bits(),
611+
0,
612+
n_rows,
613+
index_interleaved.pq_dim(),
614+
interleaved_codes.data());
615+
616+
ASSERT_TRUE(cuvs::devArrMatch(interleaved_codes.data(),
617+
list_flat->data.data_handle(),
618+
n_rows * bytes_per_vector,
619+
cuvs::Compare<uint8_t>{}))
620+
<< "PQ codes mismatch at list " << label;
621+
}
622+
}
623+
568624
void SetUp() override // NOLINT
569625
{
570626
gen_data();
@@ -1147,7 +1203,76 @@ inline auto special_cases() -> test_cases_t
11471203
#define TEST_BUILD_PRECOMPUTED(type) \
11481204
TEST_P(type, build_precomputed) /* NOLINT */ { this->build_precomputed(); }
11491205

1206+
#define TEST_FLAT_LAYOUT_CODES(type) \
1207+
TEST_P(type, flat_layout_codes) /* NOLINT */ { this->check_flat_layout_codes(); }
1208+
11501209
#define INSTANTIATE(type, vals) \
11511210
INSTANTIATE_TEST_SUITE_P(IvfPq, type, ::testing::ValuesIn(vals)); /* NOLINT */
11521211

1212+
/**
1213+
* Test cases for flat layout comparison.
1214+
* These test all pq_bits values (4-8) to ensure correct encoding.
1215+
*/
1216+
inline auto flat_layout_tests() -> test_cases_t
1217+
{
1218+
test_cases_t xs;
1219+
1220+
// Test with pq_bits = 4 (byte-aligned, 2 codes per byte)
1221+
add_test_case(xs, [](ivf_pq_inputs& x) {
1222+
x.num_db_vecs = 1000;
1223+
x.dim = 64;
1224+
x.index_params.n_lists = 10;
1225+
x.index_params.pq_bits = 4;
1226+
x.index_params.pq_dim = 16;
1227+
});
1228+
1229+
// Test with pq_bits = 5 (not byte-aligned)
1230+
add_test_case(xs, [](ivf_pq_inputs& x) {
1231+
x.num_db_vecs = 1000;
1232+
x.dim = 64;
1233+
x.index_params.n_lists = 10;
1234+
x.index_params.pq_bits = 5;
1235+
x.index_params.pq_dim = 16;
1236+
});
1237+
1238+
// Test with pq_bits = 6 (not byte-aligned)
1239+
add_test_case(xs, [](ivf_pq_inputs& x) {
1240+
x.num_db_vecs = 1000;
1241+
x.dim = 64;
1242+
x.index_params.n_lists = 10;
1243+
x.index_params.pq_bits = 6;
1244+
x.index_params.pq_dim = 16;
1245+
});
1246+
1247+
// Test with pq_bits = 7 (not byte-aligned)
1248+
add_test_case(xs, [](ivf_pq_inputs& x) {
1249+
x.num_db_vecs = 1000;
1250+
x.dim = 64;
1251+
x.index_params.n_lists = 10;
1252+
x.index_params.pq_bits = 7;
1253+
x.index_params.pq_dim = 16;
1254+
});
1255+
1256+
// Test with pq_bits = 8 (byte-aligned, 1 code per byte)
1257+
add_test_case(xs, [](ivf_pq_inputs& x) {
1258+
x.num_db_vecs = 1000;
1259+
x.dim = 64;
1260+
x.index_params.n_lists = 10;
1261+
x.index_params.pq_bits = 8;
1262+
x.index_params.pq_dim = 16;
1263+
});
1264+
1265+
// Test with PER_CLUSTER codebook
1266+
add_test_case(xs, [](ivf_pq_inputs& x) {
1267+
x.num_db_vecs = 1000;
1268+
x.dim = 64;
1269+
x.index_params.n_lists = 10;
1270+
x.index_params.pq_bits = 8;
1271+
x.index_params.pq_dim = 16;
1272+
x.index_params.codebook_kind = codebook_gen::PER_CLUSTER;
1273+
});
1274+
1275+
return xs;
1276+
}
1277+
11531278
} // namespace cuvs::neighbors::ivf_pq

cpp/tests/neighbors/ann_ivf_pq/test_float_int64_t.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ INSTANTIATE(f32_f32_i64,
1919
defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() +
2020
enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine());
2121

22+
TEST_FLAT_LAYOUT_CODES(f32_f32_i64)
23+
INSTANTIATE_TEST_SUITE_P(IvfPqFlatLayout, f32_f32_i64, ::testing::ValuesIn(flat_layout_tests()));
24+
2225
TEST_BUILD_SEARCH(f32_f32_i64_filter)
2326
INSTANTIATE(f32_f32_i64_filter,
2427
defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() +

0 commit comments

Comments
 (0)