From 2a83fbf0118ffd6e02063f3b1f89c22223271afc Mon Sep 17 00:00:00 2001 From: Michael Norris Date: Thu, 6 Feb 2025 13:17:50 -0800 Subject: [PATCH] Add sharding convenience function for IVF indexes (#4150) Summary: Creates a sharding convenience function for IVF indexes. - The __**centroids on the quantizer**__ are sharded based on the given sharding function. (not the data, as data sharding by ids is already implemented by copy_subuset_to, https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVF.h#L408) - The output is written to files based on the template filename generator param. - The default sharding function is simply the ith vector mod the total shard count. This would called by Laser here: https://www.internalfb.com/code/fbsource/[ce1f2e028e79]/fbcode/fblearner/flow/projects/laser/laser_sim_search/knn_trainer.py?lines=295-296. This convenience function will do the file writing, and return the created file names. There's a few key required changes in FAISS: 1. Allow `std::vector` to be used. Updates swigfaiss.swig and array_conversions.py to accommodate. These have to be numpy dtype of `object` instead of the more correct `unicode`, because unicode dtype is fixed length. I couldn't figure out how to create a numpy array with each of the output file names where they have different dtypes. (Say the file names are like file1, file11, file111. The dtype would need to be U5, U6, U7 respectively, as the dtype for unicode contains the length). I tried structured arrays : this does not work either, as numpy makes it into a matrix instead: the `file1 file11 file111` example with explicit setting of U5, U6, U7 turns into `[[file1 file1 file1], [file1 file11 file11], [file1 file11 file111]]`, which we do not want. If someone knows the right syntax, please yell at me 2. Create Python callbacks for sharding and template filename: `PyCallbackFilenameTemplateGenerator` and `PyCallbackShardingFunction`. Users of this function would inherit from the FilenameTemplateGenerator or ShardingFunction in C++ to pass to `shard_ivf_index_centroids`. See the other examples in python_callbacks.cpp. This is required because Python functions cannot be passed through SWIG to C++ (i.e. no std::function or function pointers), so we have to use this approach. This approach allows it to be called from both C++ and Python. test_sharding.py shows the Python calling, test_utils.cpp shows the C++ calling. Differential Revision: D68534991 --- faiss/IVFlib.cpp | 143 ++++++++++++++++++++++++++++++ faiss/IVFlib.h | 38 ++++++++ faiss/clone_index.cpp | 34 +++++++ faiss/python/__init__.py | 3 +- faiss/python/class_wrappers.py | 9 ++ faiss/python/python_callbacks.cpp | 24 +++++ faiss/python/python_callbacks.h | 22 +++++ tests/test_ivflib.py | 140 +++++++++++++++++++++++++++++ 8 files changed, 412 insertions(+), 1 deletion(-) diff --git a/faiss/IVFlib.cpp b/faiss/IVFlib.cpp index 83812f6abe..11900f4b09 100644 --- a/faiss/IVFlib.cpp +++ b/faiss/IVFlib.cpp @@ -16,7 +16,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -519,5 +521,146 @@ void ivf_residual_add_from_flat_codes( index->ntotal += nb; } +int64_t DefaultShardingFunction::operator()(int64_t i, int64_t shard_count) { + return i % shard_count; +} + +void handle_ivf( + faiss::IndexIVF* index, + int64_t shard_count, + const std::string& filename_template, + ShardingFunction* sharding_function) { + std::vector sharded_indexes(shard_count); + auto clone = static_cast(faiss::clone_index(index)); + clone->quantizer->reset(); + for (int64_t i = 0; i < shard_count; i++) { + sharded_indexes[i] = + static_cast(faiss::clone_index(clone)); + } + + // assign centroids to each sharded Index based on sharding_function, and + // add them to the quantizer of each sharded index + std::vector> sharded_centroids(shard_count); + for (int64_t i = 0; i < index->quantizer->ntotal; i++) { + int64_t shard_id = (*sharding_function)(i, shard_count); + float* reconstructed = new float[index->quantizer->d]; + index->quantizer->reconstruct(i, reconstructed); + sharded_centroids[shard_id].insert( + sharded_centroids[shard_id].end(), + &reconstructed[0], + &reconstructed[index->quantizer->d]); + delete[] reconstructed; + } + for (int64_t i = 0; i < shard_count; i++) { + sharded_indexes[i]->quantizer->add( + sharded_centroids[i].size() / index->quantizer->d, + sharded_centroids[i].data()); + } + + for (int64_t i = 0; i < shard_count; i++) { + char fname[256]; + snprintf(fname, 256, filename_template.c_str(), i); + faiss::write_index(sharded_indexes[i], fname); + } + + for (int64_t i = 0; i < shard_count; i++) { + delete sharded_indexes[i]; + } +} + +void handle_binary_ivf( + faiss::IndexBinaryIVF* index, + int64_t shard_count, + const std::string& filename_template, + ShardingFunction* sharding_function) { + std::vector sharded_indexes(shard_count); + + auto clone = static_cast( + faiss::clone_binary_index(index)); + clone->quantizer->reset(); + + for (int64_t i = 0; i < shard_count; i++) { + sharded_indexes[i] = static_cast( + faiss::clone_binary_index(clone)); + } + + // assign centroids to each sharded Index based on sharding_function, and + // add them to the quantizer of each sharded index + int64_t reconstruction_size = index->quantizer->d / 8; + std::vector> sharded_centroids(shard_count); + for (int64_t i = 0; i < index->quantizer->ntotal; i++) { + int64_t shard_id = (*sharding_function)(i, shard_count); + uint8_t* reconstructed = new uint8_t[reconstruction_size]; + index->quantizer->reconstruct(i, reconstructed); + sharded_centroids[shard_id].insert( + sharded_centroids[shard_id].end(), + &reconstructed[0], + &reconstructed[reconstruction_size]); + delete[] reconstructed; + } + for (int64_t i = 0; i < shard_count; i++) { + sharded_indexes[i]->quantizer->add( + sharded_centroids[i].size() / reconstruction_size, + sharded_centroids[i].data()); + } + + for (int64_t i = 0; i < shard_count; i++) { + char fname[256]; + snprintf(fname, 256, filename_template.c_str(), i); + faiss::write_index_binary(sharded_indexes[i], fname); + } + + for (int64_t i = 0; i < shard_count; i++) { + delete sharded_indexes[i]; + } +} + +template +void sharding_helper( + IndexType* index, + int64_t shard_count, + const std::string& filename_template, + ShardingFunction* sharding_function) { + FAISS_THROW_IF_MSG(index->quantizer->ntotal == 0, "No centroids to shard."); + FAISS_THROW_IF_MSG( + filename_template.find("%d") == std::string::npos, + "Invalid filename_template. Must contain format specifier for shard count."); + + DefaultShardingFunction default_sharding_function; + if (sharding_function == nullptr) { + sharding_function = &default_sharding_function; + } + + if (typeid(IndexType) == typeid(faiss::IndexIVF)) { + handle_ivf( + dynamic_cast(index), + shard_count, + filename_template, + sharding_function); + } else if (typeid(IndexType) == typeid(faiss::IndexBinaryIVF)) { + handle_binary_ivf( + dynamic_cast(index), + shard_count, + filename_template, + sharding_function); + } +} + +void shard_ivf_index_centroids( + faiss::IndexIVF* index, + int64_t shard_count, + const std::string& filename_template, + ShardingFunction* sharding_function) { + sharding_helper(index, shard_count, filename_template, sharding_function); +} + +void shard_binary_ivf_index_centroids( + faiss::IndexBinaryIVF* index, + int64_t shard_count, + const std::string& filename_template, + ShardingFunction* sharding_function) { + sharding_helper(index, shard_count, filename_template, sharding_function); +} + } // namespace ivflib } // namespace faiss diff --git a/faiss/IVFlib.h b/faiss/IVFlib.h index 6f6a590c72..8a83ba515e 100644 --- a/faiss/IVFlib.h +++ b/faiss/IVFlib.h @@ -14,6 +14,7 @@ * IndexIVFs embedded within an IndexPreTransform. */ +#include #include #include @@ -167,6 +168,43 @@ void ivf_residual_add_from_flat_codes( const uint8_t* codes, int64_t code_size = -1); +struct ShardingFunction { + virtual int64_t operator()(int64_t i, int64_t shard_count) = 0; + virtual ~ShardingFunction() = default; + ShardingFunction() {} + ShardingFunction(const ShardingFunction&) = default; + ShardingFunction(ShardingFunction&&) = default; + ShardingFunction& operator=(const ShardingFunction&) = default; + ShardingFunction& operator=(ShardingFunction&&) = default; +}; +struct DefaultShardingFunction : ShardingFunction { + int64_t operator()(int64_t i, int64_t shard_count) override; +}; + +/** + * Shards an IVF index centroids by the given sharding function, and writes + * the index to the path given by filename_generator. The centroids must already + * be added to the index quantizer. + * + * @param index The IVF index containing centroids to shard. + * @param shard_count Number of shards. + * @param filename_template Template for shard filenames. + * @param sharding_function The function to shard by. The default is ith vector + * mod shard_count. + * @return The number of shards written. + */ +void shard_ivf_index_centroids( + IndexIVF* index, + int64_t shard_count = 20, + const std::string& filename_template = "shard.%d.index", + ShardingFunction* sharding_function = nullptr); + +void shard_binary_ivf_index_centroids( + faiss::IndexBinaryIVF* index, + int64_t shard_count = 20, + const std::string& filename_template = "shard.%d.index", + ShardingFunction* sharding_function = nullptr); + } // namespace ivflib } // namespace faiss diff --git a/faiss/clone_index.cpp b/faiss/clone_index.cpp index 7174cd6ae0..bc08283740 100644 --- a/faiss/clone_index.cpp +++ b/faiss/clone_index.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include #include #include @@ -107,6 +109,11 @@ IndexIVF* Cloner::clone_IndexIVF(const IndexIVF* ivf) { return nullptr; } +IndexBinaryIVF* clone_IndexBinaryIVF(const IndexBinaryIVF* ivf) { + TRYCLONE(IndexBinaryIVF, ivf) + return nullptr; +} + IndexRefine* clone_IndexRefine(const IndexRefine* ir) { TRYCLONE(IndexRefineFlat, ir) TRYCLONE(IndexRefine, ir) { @@ -131,6 +138,11 @@ IndexHNSW* clone_IndexHNSW(const IndexHNSW* ihnsw) { } } +IndexBinaryHNSW* clone_IndexBinaryHNSW(const IndexBinaryHNSW* ihnsw) { + TRYCLONE(IndexBinaryHNSW, ihnsw) + return nullptr; +} + IndexNNDescent* clone_IndexNNDescent(const IndexNNDescent* innd) { TRYCLONE(IndexNNDescentFlat, innd) TRYCLONE(IndexNNDescent, innd) { @@ -385,6 +397,28 @@ Quantizer* clone_Quantizer(const Quantizer* quant) { IndexBinary* clone_binary_index(const IndexBinary* index) { if (auto ii = dynamic_cast(index)) { return new IndexBinaryFlat(*ii); + } else if ( + const IndexBinaryIVF* ivf = + dynamic_cast(index)) { + IndexBinaryIVF* res = clone_IndexBinaryIVF(ivf); + if (ivf->invlists == nullptr) { + res->invlists = nullptr; + } else { + res->invlists = clone_InvertedLists(ivf->invlists); + res->own_invlists = true; + } + + res->own_fields = true; + res->quantizer = clone_binary_index(ivf->quantizer); + + return res; + } else if ( + const IndexBinaryHNSW* ihnsw = + dynamic_cast(index)) { + IndexBinaryHNSW* res = clone_IndexBinaryHNSW(ihnsw); + res->own_fields = true; + res->storage = clone_binary_index(ihnsw->storage); + return res; } else { FAISS_THROW_MSG("cannot clone this type of index"); } diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 9d956ebe71..7266da71f3 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -53,6 +53,7 @@ class_wrappers.handle_Linear(Linear) class_wrappers.handle_QINCo(QINCo) class_wrappers.handle_QINCoStep(QINCoStep) +shard_ivf_index_centroids = class_wrappers.handle_shard_ivf_index_centroids(shard_ivf_index_centroids) this_module = sys.modules[__name__] @@ -170,7 +171,7 @@ def replacement_function(*args): add_ref_in_constructor(GpuIndexIVFPQ, 1) add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1) except NameError as e: - logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0]) + logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss." % e.args[0]) add_ref_in_constructor(IndexIVFFlat, 0) add_ref_in_constructor(IndexIVFFlatDedup, 0) diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 607fdd6d29..46f8b0195f 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -1395,3 +1395,12 @@ def from_torch(self, qinco): the_class.__init__ = replacement_init the_class.from_torch = from_torch + + +def handle_shard_ivf_index_centroids(func): + def wrapper(*args, **kwargs): + args = list(args) + if len(args) > 3 and args[3] is not None: + args[3] = faiss.PyCallbackShardingFunction(args[3]) + return func(*args, **kwargs) + return wrapper diff --git a/faiss/python/python_callbacks.cpp b/faiss/python/python_callbacks.cpp index ce36bed437..8b78bf1e43 100644 --- a/faiss/python/python_callbacks.cpp +++ b/faiss/python/python_callbacks.cpp @@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() { PyThreadLock gil; Py_DECREF(callback); } + +/*********************************************************** + * Callbacks for IVF index sharding + ***********************************************************/ + +PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback) + : callback(callback) { + PyThreadLock gil; + Py_INCREF(callback); +} + +int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) { + PyThreadLock gil; + PyObject* shard_id = PyObject_CallFunction(callback, "LL", i, shard_count); + if (shard_id == nullptr) { + FAISS_THROW_MSG("propagate py error"); + } + return PyLong_AsLongLong(shard_id); +} + +PyCallbackShardingFunction::~PyCallbackShardingFunction() { + PyThreadLock gil; + Py_DECREF(callback); +} diff --git a/faiss/python/python_callbacks.h b/faiss/python/python_callbacks.h index fa8ebaf53c..072e69f91f 100644 --- a/faiss/python/python_callbacks.h +++ b/faiss/python/python_callbacks.h @@ -7,6 +7,7 @@ #pragma once +#include #include #include #include @@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector { ~PyCallbackIDSelector() override; }; + +/*********************************************************** + * Callbacks for IVF index sharding + ***********************************************************/ + +struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction { + PyObject* callback; + + explicit PyCallbackShardingFunction(PyObject* callback); + + int64_t operator()(int64_t i, int64_t shard_count) override; + + ~PyCallbackShardingFunction() override; + + PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete; + PyCallbackShardingFunction(PyCallbackShardingFunction&&) noexcept = default; + PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) = + default; + PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) = + default; +}; diff --git a/tests/test_ivflib.py b/tests/test_ivflib.py index d905f3d486..4121304689 100644 --- a/tests/test_ivflib.py +++ b/tests/test_ivflib.py @@ -8,6 +8,9 @@ import unittest import faiss import numpy as np +import os +import random + class TestIVFlib(unittest.TestCase): @@ -180,3 +183,140 @@ def test_small_data(self): assert np.all(lims == ref_lims) assert np.all(D == ref_D) assert np.all(I == ref_I) + + +class TestIvfSharding(unittest.TestCase): + d = 32 + nlist = 100 + nb = 1000 + + def custom_sharding_function(self, i, _): + return 1 if i % 2 == 0 else 7 + + # Mimics the default in DefaultShardingFunction. + # This impl is just used for verification. + def default_sharding_function(self, i, shard_count): + return i % shard_count + + def verify_sharded_ivf_indexes( + self, template, xb, shard_count, sharding_function): + sharded_indexes_counters = [0] * shard_count + sharded_indexes = [] + for i in range(shard_count): + if xb[0].dtype.name == 'uint8': + index = faiss.read_index_binary(template % i) + else: + index = faiss.read_index(template % i) + sharded_indexes.append(index) + # Reconstruct and verify each centroid + nb = len(xb) + for i in range(nb): + shard_id = sharding_function(i, shard_count) + reconstructed = sharded_indexes[shard_id].quantizer.reconstruct( + sharded_indexes_counters[shard_id]) + sharded_indexes_counters[shard_id] += 1 + print(f"reconstructed: {reconstructed} xb[i]: {xb[i]}") + np.testing.assert_array_equal(reconstructed, xb[i]) + # Clean up + for i in range(shard_count): + os.remove(template % i) + + def test_save_index_shards_by_centroids_no_op(self): + quantizer = faiss.IndexFlatL2(self.d) + index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist) + with self.assertRaises(RuntimeError): + faiss.shard_ivf_index_centroids( + index, + 10, + "shard.%d.index", + None + ) + + def test_save_index_shards_by_centroids_flat_quantizer_default_sharding( + self): + xb = np.random.rand(self.nb, self.d).astype('float32') + quantizer = faiss.IndexFlatL2(self.d) + index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist) + shard_count = 3 + + index.quantizer.add(xb) + + template = str(random.randint(0, 100000)) + "shard.%d.index" + faiss.shard_ivf_index_centroids( + index, + shard_count, + template + ) + self.verify_sharded_ivf_indexes( + template, xb, shard_count, self.default_sharding_function) + + def test_save_index_shards_by_centroids_flat_quantizer_custom_sharding( + self): + xb = np.random.rand(self.nb, self.d).astype('float32') + quantizer = faiss.IndexFlatL2(self.d) + index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist) + shard_count = 20 + + index.quantizer.add(xb) + + template = str(random.randint(0, 100000)) + "shard.%d.index" + faiss.shard_ivf_index_centroids( + index, + shard_count, + template, + self.custom_sharding_function + ) + self.verify_sharded_ivf_indexes( + template, xb, shard_count, self.custom_sharding_function) + + def test_save_index_shards_by_centroids_hnsw_quantizer(self): + xb = np.random.rand(self.nb, self.d).astype('float32') + quantizer = faiss.IndexHNSWFlat(self.d, 32) + index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist) + shard_count = 17 + + index.quantizer.add(xb) + + template = str(random.randint(0, 100000)) + "shard.%d.index" + faiss.shard_ivf_index_centroids( + index, + shard_count, + template, + None + ) + self.verify_sharded_ivf_indexes( + template, xb, shard_count, self.default_sharding_function) + + def test_save_index_shards_by_centroids_binary_flat_quantizer(self): + xb = np.random.randint(256, size=(self.nb, int(self.d / 8))).astype('uint8') + quantizer = faiss.IndexBinaryFlat(self.d) + index = faiss.IndexBinaryIVF(quantizer, self.d, self.nlist) + shard_count = 11 + + index.quantizer.add(xb) + + template = str(random.randint(0, 100000)) + "shard.%d.index" + faiss.shard_binary_ivf_index_centroids( + index, + shard_count, + template + ) + self.verify_sharded_ivf_indexes( + template, xb, shard_count, self.default_sharding_function) + + def test_save_index_shards_by_centroids_binary_hnsw_quantizer(self): + xb = np.random.randint(256, size=(self.nb, int(self.d / 8))).astype('uint8') + quantizer = faiss.IndexBinaryHNSW(self.d, 32) + index = faiss.IndexBinaryIVF(quantizer, self.d, self.nlist) + shard_count = 13 + + index.quantizer.add(xb) + + template = str(random.randint(0, 100000)) + "shard.%d.index" + faiss.shard_binary_ivf_index_centroids( + index, + shard_count, + template + ) + self.verify_sharded_ivf_indexes( + template, xb, shard_count, self.default_sharding_function)