@@ -565,6 +565,62 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
565565 }
566566 }
567567
568+ /* *
569+ * Test that FLAT layout produces the same PQ codes as INTERLEAVED layout.
570+ * Builds both layouts and compares the unpacked codes.
571+ */
572+ void check_flat_layout_codes ()
573+ {
574+ auto ipams_interleaved = ps.index_params ;
575+ ipams_interleaved.add_data_on_build = true ;
576+ ipams_interleaved.codes_layout = list_layout::INTERLEAVED;
577+
578+ auto database_view =
579+ raft::make_device_matrix_view<const DataT, int64_t >(database.data (), ps.num_db_vecs , ps.dim );
580+ auto index_interleaved =
581+ cuvs::neighbors::ivf_pq::build (handle_, ipams_interleaved, database_view);
582+
583+ auto ipams_flat = ps.index_params ;
584+ ipams_flat.add_data_on_build = true ;
585+ ipams_flat.codes_layout = list_layout::FLAT;
586+
587+ auto index_flat = cuvs::neighbors::ivf_pq::build (handle_, ipams_flat, database_view);
588+
589+ ASSERT_EQ (index_interleaved.codes_layout (), list_layout::INTERLEAVED);
590+ ASSERT_EQ (index_flat.codes_layout (), list_layout::FLAT);
591+
592+ ASSERT_TRUE (cuvs::devArrMatch (index_interleaved.list_sizes ().data_handle (),
593+ index_flat.list_sizes ().data_handle (),
594+ index_interleaved.n_lists (),
595+ cuvs::Compare<uint32_t >{}));
596+
597+ uint32_t bytes_per_vector =
598+ raft::div_rounding_up_safe (index_flat.pq_dim () * index_flat.pq_bits (), 8u );
599+
600+ for (uint32_t label = 0 ; label < index_interleaved.n_lists (); label++) {
601+ auto & list_interleaved = index_interleaved.lists ()[label];
602+ auto & list_flat = index_flat.lists ()[label];
603+
604+ uint32_t n_rows = list_interleaved->size .load ();
605+ if (n_rows == 0 ) { continue ; }
606+
607+ rmm::device_uvector<uint8_t > interleaved_codes (n_rows * bytes_per_vector, stream_);
608+ helpers::codepacker::unpack_contiguous (handle_,
609+ list_interleaved->data .view (),
610+ index_interleaved.pq_bits (),
611+ 0 ,
612+ n_rows,
613+ index_interleaved.pq_dim (),
614+ interleaved_codes.data ());
615+
616+ ASSERT_TRUE (cuvs::devArrMatch (interleaved_codes.data (),
617+ list_flat->data .data_handle (),
618+ n_rows * bytes_per_vector,
619+ cuvs::Compare<uint8_t >{}))
620+ << " PQ codes mismatch at list " << label;
621+ }
622+ }
623+
568624 void SetUp () override // NOLINT
569625 {
570626 gen_data ();
@@ -1147,7 +1203,76 @@ inline auto special_cases() -> test_cases_t
11471203#define TEST_BUILD_PRECOMPUTED (type ) \
11481204 TEST_P (type, build_precomputed) /* NOLINT */ { this ->build_precomputed (); }
11491205
1206+ #define TEST_FLAT_LAYOUT_CODES (type ) \
1207+ TEST_P (type, flat_layout_codes) /* NOLINT */ { this ->check_flat_layout_codes (); }
1208+
11501209#define INSTANTIATE (type, vals ) \
11511210 INSTANTIATE_TEST_SUITE_P (IvfPq, type, ::testing::ValuesIn(vals)); /* NOLINT */
11521211
1212+ /* *
1213+ * Test cases for flat layout comparison.
1214+ * These test all pq_bits values (4-8) to ensure correct encoding.
1215+ */
1216+ inline auto flat_layout_tests () -> test_cases_t
1217+ {
1218+ test_cases_t xs;
1219+
1220+ // Test with pq_bits = 4 (byte-aligned, 2 codes per byte)
1221+ add_test_case (xs, [](ivf_pq_inputs& x) {
1222+ x.num_db_vecs = 1000 ;
1223+ x.dim = 64 ;
1224+ x.index_params .n_lists = 10 ;
1225+ x.index_params .pq_bits = 4 ;
1226+ x.index_params .pq_dim = 16 ;
1227+ });
1228+
1229+ // Test with pq_bits = 5 (not byte-aligned)
1230+ add_test_case (xs, [](ivf_pq_inputs& x) {
1231+ x.num_db_vecs = 1000 ;
1232+ x.dim = 64 ;
1233+ x.index_params .n_lists = 10 ;
1234+ x.index_params .pq_bits = 5 ;
1235+ x.index_params .pq_dim = 16 ;
1236+ });
1237+
1238+ // Test with pq_bits = 6 (not byte-aligned)
1239+ add_test_case (xs, [](ivf_pq_inputs& x) {
1240+ x.num_db_vecs = 1000 ;
1241+ x.dim = 64 ;
1242+ x.index_params .n_lists = 10 ;
1243+ x.index_params .pq_bits = 6 ;
1244+ x.index_params .pq_dim = 16 ;
1245+ });
1246+
1247+ // Test with pq_bits = 7 (not byte-aligned)
1248+ add_test_case (xs, [](ivf_pq_inputs& x) {
1249+ x.num_db_vecs = 1000 ;
1250+ x.dim = 64 ;
1251+ x.index_params .n_lists = 10 ;
1252+ x.index_params .pq_bits = 7 ;
1253+ x.index_params .pq_dim = 16 ;
1254+ });
1255+
1256+ // Test with pq_bits = 8 (byte-aligned, 1 code per byte)
1257+ add_test_case (xs, [](ivf_pq_inputs& x) {
1258+ x.num_db_vecs = 1000 ;
1259+ x.dim = 64 ;
1260+ x.index_params .n_lists = 10 ;
1261+ x.index_params .pq_bits = 8 ;
1262+ x.index_params .pq_dim = 16 ;
1263+ });
1264+
1265+ // Test with PER_CLUSTER codebook
1266+ add_test_case (xs, [](ivf_pq_inputs& x) {
1267+ x.num_db_vecs = 1000 ;
1268+ x.dim = 64 ;
1269+ x.index_params .n_lists = 10 ;
1270+ x.index_params .pq_bits = 8 ;
1271+ x.index_params .pq_dim = 16 ;
1272+ x.index_params .codebook_kind = codebook_gen::PER_CLUSTER;
1273+ });
1274+
1275+ return xs;
1276+ }
1277+
11531278} // namespace cuvs::neighbors::ivf_pq
0 commit comments