99#include < omp.h>
1010
1111#include < memory>
12+ #include < numeric>
1213
1314#include < faiss/IndexAdditiveQuantizer.h>
1415#include < faiss/IndexIVFAdditiveQuantizer.h>
@@ -529,20 +530,30 @@ void handle_ivf(
529530 faiss::IndexIVF* index,
530531 int64_t shard_count,
531532 const std::string& filename_template,
532- ShardingFunction* sharding_function) {
533+ ShardingFunction* sharding_function,
534+ bool generate_ids) {
533535 std::vector<faiss::IndexIVF*> sharded_indexes (shard_count);
534536 auto clone = static_cast <faiss::IndexIVF*>(faiss::clone_index (index));
535537 clone->quantizer ->reset ();
536538 for (int64_t i = 0 ; i < shard_count; i++) {
537539 sharded_indexes[i] =
538540 static_cast <faiss::IndexIVF*>(faiss::clone_index (clone));
541+ if (generate_ids) {
542+ // Assume the quantizer does not natively support add_with_ids.
543+ sharded_indexes[i]->quantizer =
544+ new IndexIDMap2 (sharded_indexes[i]->quantizer );
545+ }
539546 }
540547
541548 // assign centroids to each sharded Index based on sharding_function, and
542549 // add them to the quantizer of each sharded index
543550 std::vector<std::vector<float >> sharded_centroids (shard_count);
551+ std::vector<std::vector<idx_t >> xids (shard_count);
544552 for (int64_t i = 0 ; i < index->quantizer ->ntotal ; i++) {
545553 int64_t shard_id = (*sharding_function)(i, shard_count);
554+ // Since the quantizer does not natively support add_with_ids, we simply
555+ // generate them.
556+ xids[shard_id].push_back (i);
546557 float * reconstructed = new float [index->quantizer ->d ];
547558 index->quantizer ->reconstruct (i, reconstructed);
548559 sharded_centroids[shard_id].insert (
@@ -552,9 +563,16 @@ void handle_ivf(
552563 delete[] reconstructed;
553564 }
554565 for (int64_t i = 0 ; i < shard_count; i++) {
555- sharded_indexes[i]->quantizer ->add (
556- sharded_centroids[i].size () / index->quantizer ->d ,
557- sharded_centroids[i].data ());
566+ if (generate_ids) {
567+ sharded_indexes[i]->quantizer ->add_with_ids (
568+ sharded_centroids[i].size () / index->quantizer ->d ,
569+ sharded_centroids[i].data (),
570+ xids[i].data ());
571+ } else {
572+ sharded_indexes[i]->quantizer ->add (
573+ sharded_centroids[i].size () / index->quantizer ->d ,
574+ sharded_centroids[i].data ());
575+ }
558576 }
559577
560578 for (int64_t i = 0 ; i < shard_count; i++) {
@@ -572,7 +590,8 @@ void handle_binary_ivf(
572590 faiss::IndexBinaryIVF* index,
573591 int64_t shard_count,
574592 const std::string& filename_template,
575- ShardingFunction* sharding_function) {
593+ ShardingFunction* sharding_function,
594+ bool generate_ids) {
576595 std::vector<faiss::IndexBinaryIVF*> sharded_indexes (shard_count);
577596
578597 auto clone = static_cast <faiss::IndexBinaryIVF*>(
@@ -582,14 +601,23 @@ void handle_binary_ivf(
582601 for (int64_t i = 0 ; i < shard_count; i++) {
583602 sharded_indexes[i] = static_cast <faiss::IndexBinaryIVF*>(
584603 faiss::clone_binary_index (clone));
604+ if (generate_ids) {
605+ // Assume the quantizer does not natively support add_with_ids.
606+ sharded_indexes[i]->quantizer =
607+ new IndexBinaryIDMap2 (sharded_indexes[i]->quantizer );
608+ }
585609 }
586610
587611 // assign centroids to each sharded Index based on sharding_function, and
588612 // add them to the quantizer of each sharded index
589613 int64_t reconstruction_size = index->quantizer ->d / 8 ;
590614 std::vector<std::vector<uint8_t >> sharded_centroids (shard_count);
615+ std::vector<std::vector<idx_t >> xids (shard_count);
591616 for (int64_t i = 0 ; i < index->quantizer ->ntotal ; i++) {
592617 int64_t shard_id = (*sharding_function)(i, shard_count);
618+ // Since the quantizer does not natively support add_with_ids, we simply
619+ // generate them.
620+ xids[shard_id].push_back (i);
593621 uint8_t * reconstructed = new uint8_t [reconstruction_size];
594622 index->quantizer ->reconstruct (i, reconstructed);
595623 sharded_centroids[shard_id].insert (
@@ -599,9 +627,16 @@ void handle_binary_ivf(
599627 delete[] reconstructed;
600628 }
601629 for (int64_t i = 0 ; i < shard_count; i++) {
602- sharded_indexes[i]->quantizer ->add (
603- sharded_centroids[i].size () / reconstruction_size,
604- sharded_centroids[i].data ());
630+ if (generate_ids) {
631+ sharded_indexes[i]->quantizer ->add_with_ids (
632+ sharded_centroids[i].size () / reconstruction_size,
633+ sharded_centroids[i].data (),
634+ xids[i].data ());
635+ } else {
636+ sharded_indexes[i]->quantizer ->add (
637+ sharded_centroids[i].size () / reconstruction_size,
638+ sharded_centroids[i].data ());
639+ }
605640 }
606641
607642 for (int64_t i = 0 ; i < shard_count; i++) {
@@ -620,7 +655,8 @@ void sharding_helper(
620655 IndexType* index,
621656 int64_t shard_count,
622657 const std::string& filename_template,
623- ShardingFunction* sharding_function) {
658+ ShardingFunction* sharding_function,
659+ bool generate_ids) {
624660 FAISS_THROW_IF_MSG (index->quantizer ->ntotal == 0 , " No centroids to shard." );
625661 FAISS_THROW_IF_MSG (
626662 filename_template.find (" %d" ) == std::string::npos,
@@ -636,30 +672,44 @@ void sharding_helper(
636672 dynamic_cast <faiss::IndexIVF*>(index),
637673 shard_count,
638674 filename_template,
639- sharding_function);
675+ sharding_function,
676+ generate_ids);
640677 } else if (typeid (IndexType) == typeid (faiss::IndexBinaryIVF)) {
641678 handle_binary_ivf (
642679 dynamic_cast <faiss::IndexBinaryIVF*>(index),
643680 shard_count,
644681 filename_template,
645- sharding_function);
682+ sharding_function,
683+ generate_ids);
646684 }
647685}
648686
649687void shard_ivf_index_centroids (
650688 faiss::IndexIVF* index,
651689 int64_t shard_count,
652690 const std::string& filename_template,
653- ShardingFunction* sharding_function) {
654- sharding_helper (index, shard_count, filename_template, sharding_function);
691+ ShardingFunction* sharding_function,
692+ bool generate_ids) {
693+ sharding_helper (
694+ index,
695+ shard_count,
696+ filename_template,
697+ sharding_function,
698+ generate_ids);
655699}
656700
657701void shard_binary_ivf_index_centroids (
658702 faiss::IndexBinaryIVF* index,
659703 int64_t shard_count,
660704 const std::string& filename_template,
661- ShardingFunction* sharding_function) {
662- sharding_helper (index, shard_count, filename_template, sharding_function);
705+ ShardingFunction* sharding_function,
706+ bool generate_ids) {
707+ sharding_helper (
708+ index,
709+ shard_count,
710+ filename_template,
711+ sharding_function,
712+ generate_ids);
663713}
664714
665715} // namespace ivflib
0 commit comments