Skip to content

Commit 7dfa2f8

Browse files
Michael Norrisfacebook-github-bot
authored andcommitted
Add sharding convenience function for IVF indexes (#4150)
Summary: Creates a sharding convenience function for IVF indexes. - The centroids on the quantizer are sharded based on the given sharding function. - The output is written to files based on the template filename generator param. - The default sharding function is simply the ith vector mod the total shard count. This would called by Laser here: https://www.internalfb.com/code/fbsource/[ce1f2e028e79]/fbcode/fblearner/flow/projects/laser/laser_sim_search/knn_trainer.py?lines=295-296. This convenience function will do the file writing, and return the created file names. There's a few key required changes in FAISS: 1. Allow `std::vector<std::string>` to be used. Updates swigfaiss.swig and array_conversions.py to accommodate. These have to be numpy dtype of `object` instead of the more correct `unicode`, because unicode dtype is fixed length. I couldn't figure out how to create a numpy array with each of the output file names where they have different dtypes. (Say the file names are like file1, file11, file111. The dtype would need to be U5, U6, U7 respectively, as the dtype for unicode contains the length). I tried structured arrays : this does not work either, as numpy makes it into a matrix instead: the `file1 file11 file111` example with explicit setting of U5, U6, U7 turns into `[[file1 file1 file1], [file1 file11 file11], [file1 file11 file111]]`, which we do not want. If someone knows the right syntax, please yell at me 2. Create Python callback for sharding and template filename. Users of this function would inherit from the FilenameTemplateGenerator or ShardingFunction to pass to `shard_ivf_index_centroids`. See the other examples in python_callbacks.cpp. This is required because Python functions cannot be passed through SWIG to C++ (i.e. no std::function or function pointers), so we have to use this approach. This approach allows it to be called from both C++ and Python. Differential Revision: D68534991
1 parent 7856053 commit 7dfa2f8

10 files changed

Lines changed: 448 additions & 5 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ set(FAISS_SRC
9191
utils/quantize_lut.cpp
9292
utils/random.cpp
9393
utils/sorting.cpp
94+
utils/sharding.cpp
9495
utils/utils.cpp
9596
utils/distances_fused/avx512.cpp
9697
utils/distances_fused/distances_fused.cpp
@@ -207,6 +208,7 @@ set(FAISS_HEADERS
207208
utils/prefetch.h
208209
utils/quantize_lut.h
209210
utils/random.h
211+
utils/sharding.h
210212
utils/sorting.h
211213
utils/simdlib.h
212214
utils/simdlib_avx2.h

faiss/python/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def replacement_function(*args):
170170
add_ref_in_constructor(GpuIndexIVFPQ, 1)
171171
add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
172172
except NameError as e:
173-
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])
173+
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss." % e.args[0])
174174

175175
add_ref_in_constructor(IndexIVFFlat, 0)
176176
add_ref_in_constructor(IndexIVFFlatDedup, 0)

faiss/python/array_conversions.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def new_meth(cls, *args, **kwargs):
9797
'UInt16': 'uint16',
9898
'UInt32': 'uint32',
9999
'UInt64': 'uint64',
100+
'String': 'str',
100101
**{k: v.lower() for k, v in deprecated_name_map.items()}
101102
}
102103

@@ -107,10 +108,19 @@ def vector_to_array(v):
107108
if classname.startswith('AlignedTable'):
108109
return AlignedTable_to_array(v)
109110
assert classname.endswith('Vector')
110-
dtype = np.dtype(vector_name_map[classname[:-6]])
111-
a = np.empty(v.size(), dtype=dtype)
112-
if v.size() > 0:
113-
memcpy(swig_ptr(a), v.data(), a.nbytes)
111+
vector_name = vector_name_map[classname[:-6]]
112+
# TODO: Remove this hack with 'object' after upgrading to
113+
# Numpy 2, which can support variable length strings.
114+
if vector_name == 'str':
115+
values = []
116+
for i in range (0, v.size()):
117+
values.append(v.at(i))
118+
a = np.array(values, dtype='object')
119+
else:
120+
dtype = np.dtype(vector_name)
121+
a = np.empty(v.size(), dtype=dtype)
122+
if v.size() > 0:
123+
memcpy(swig_ptr(a), v.data(), a.nbytes)
114124
return a
115125

116126

faiss/python/python_callbacks.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,54 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
134134
PyThreadLock gil;
135135
Py_DECREF(callback);
136136
}
137+
138+
/***********************************************************
139+
* Callbacks for IVF index sharding
140+
***********************************************************/
141+
142+
PyCallbackFilenameTemplateGenerator::PyCallbackFilenameTemplateGenerator(
143+
PyObject* callback)
144+
: callback(callback) {
145+
PyThreadLock gil;
146+
Py_INCREF(callback);
147+
}
148+
149+
std::string PyCallbackFilenameTemplateGenerator::operator()() {
150+
PyThreadLock gil;
151+
PyObject* template_filename = PyObject_CallFunction(callback, NULL);
152+
if (template_filename == nullptr) {
153+
FAISS_THROW_MSG("propagate py error");
154+
}
155+
const char* cstr = PyUnicode_AsUTF8(template_filename);
156+
if (cstr == NULL) {
157+
// handle error or return empty string
158+
return "";
159+
}
160+
std::string result(cstr);
161+
return result;
162+
}
163+
164+
PyCallbackFilenameTemplateGenerator::~PyCallbackFilenameTemplateGenerator() {
165+
PyThreadLock gil;
166+
Py_DECREF(callback);
167+
}
168+
169+
PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
170+
: callback(callback) {
171+
PyThreadLock gil;
172+
Py_INCREF(callback);
173+
}
174+
175+
int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
176+
PyThreadLock gil;
177+
PyObject* shard_id = PyObject_CallFunction(callback, "LL", i, shard_count);
178+
if (shard_id == nullptr) {
179+
FAISS_THROW_MSG("propagate py error");
180+
}
181+
return PyLong_AsLongLong(shard_id);
182+
}
183+
184+
PyCallbackShardingFunction::~PyCallbackShardingFunction() {
185+
PyThreadLock gil;
186+
Py_DECREF(callback);
187+
}

faiss/python/python_callbacks.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <faiss/impl/IDSelector.h>
1111
#include <faiss/impl/io.h>
1212
#include <faiss/invlists/InvertedLists.h>
13+
#include <faiss/utils/sharding.h>
1314
#include "Python.h"
1415

1516
// all callbacks have to acquire the GIL on input
@@ -58,3 +59,27 @@ struct PyCallbackIDSelector : faiss::IDSelector {
5859

5960
~PyCallbackIDSelector() override;
6061
};
62+
63+
/***********************************************************
64+
* Callbacks for IVF index sharding
65+
***********************************************************/
66+
67+
struct PyCallbackFilenameTemplateGenerator : faiss::FilenameTemplateGenerator {
68+
PyObject* callback;
69+
70+
explicit PyCallbackFilenameTemplateGenerator(PyObject* callback);
71+
72+
std::string operator()();
73+
74+
~PyCallbackFilenameTemplateGenerator();
75+
};
76+
77+
struct PyCallbackShardingFunction : faiss::ShardingFunction {
78+
PyObject* callback;
79+
80+
explicit PyCallbackShardingFunction(PyObject* callback);
81+
82+
int64_t operator()(int64_t i, int64_t shard_count);
83+
84+
~PyCallbackShardingFunction();
85+
};

faiss/python/swigfaiss.swig

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ typedef uint64_t size_t;
145145
#include <faiss/IVFlib.h>
146146
#include <faiss/utils/utils.h>
147147

148+
#include <faiss/utils/sharding.h>
148149
#include <faiss/utils/sorting.h>
149150
#include <faiss/utils/distances.h>
150151
#include <faiss/utils/extra_distances.h>
@@ -270,6 +271,8 @@ namespace std {
270271
%template(UInt32Vector) std::vector<uint32_t>;
271272
%template(UInt64Vector) std::vector<uint64_t>;
272273

274+
%template(StringVector) std::vector<std::string>;
275+
273276
%template(Float32VectorVector) std::vector<std::vector<float> >;
274277
%template(UInt8VectorVector) std::vector<std::vector<uint8_t> >;
275278
%template(Int32VectorVector) std::vector<std::vector<int32_t> >;
@@ -539,6 +542,17 @@ void gpu_sync_all_devices()
539542
%include <faiss/invlists/BlockInvertedLists.h>
540543
%include <faiss/invlists/DirectMap.h>
541544
%include <faiss/IndexIVF.h>
545+
546+
// Include after IndexIVF.
547+
%shared_ptr(faiss::FilenameTemplateGenerator);
548+
%shared_ptr(PyCallbackFilenameTemplateGenerator);
549+
%shared_ptr(faiss::DefaultFilenameTemplateGenerator);
550+
%shared_ptr(faiss::ShardingFunction);
551+
%shared_ptr(PyCallbackShardingFunction);
552+
%shared_ptr(faiss::DefaultShardingFunction);
553+
554+
%include <faiss/utils/sharding.h>
555+
542556
// NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the
543557
// non-const one.
544558
%warnfilter(509) extract_index_ivf;
@@ -1163,6 +1177,9 @@ PyObject *swig_ptr (PyObject *a)
11631177
if(PyArray_TYPE(ao) == NPY_INT32) {
11641178
return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0);
11651179
}
1180+
if(PyArray_TYPE(ao) == NPY_OBJECT) {
1181+
return SWIG_NewPointerObj(data, SWIGTYPE_p_char, 0);
1182+
}
11661183
if(PyArray_TYPE(ao) == NPY_BOOL) {
11671184
return SWIG_NewPointerObj(data, SWIGTYPE_p_bool, 0);
11681185
}

faiss/utils/sharding.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <faiss/clone_index.h>
9+
#include <faiss/index_io.h>
10+
#include <faiss/utils/sharding.h>
11+
#include <cstdio>
12+
13+
namespace faiss {
14+
15+
std::string DefaultFilenameTemplateGenerator::operator()() {
16+
return "shard_%d.faissindex";
17+
}
18+
19+
int64_t DefaultShardingFunction::operator()(int64_t i, int64_t shard_count) {
20+
return i % shard_count;
21+
}
22+
23+
std::vector<std::string> shard_ivf_index_centroids(
24+
IndexIVF* index,
25+
int64_t shard_count,
26+
std::shared_ptr<FilenameTemplateGenerator> filename_template_generator,
27+
std::shared_ptr<ShardingFunction> sharding_function) {
28+
if (index->quantizer->ntotal == 0) {
29+
return std::vector<std::string>();
30+
}
31+
32+
if (filename_template_generator == nullptr) {
33+
filename_template_generator =
34+
std::make_shared<DefaultFilenameTemplateGenerator>();
35+
}
36+
if (sharding_function == nullptr) {
37+
sharding_function = std::make_shared<DefaultShardingFunction>();
38+
}
39+
40+
IndexIVF* sharded_indexes[shard_count];
41+
for (int i = 0; i < shard_count; i++) {
42+
sharded_indexes[i] = static_cast<IndexIVF*>(clone_index(index));
43+
sharded_indexes[i]->quantizer->reset();
44+
}
45+
46+
// assign centroids to each sharded Index based on sharding_function, and
47+
// add them to the quantizer of each sharded index
48+
std::vector<float> sharded_centroids[shard_count];
49+
for (int i = 0; i < index->quantizer->ntotal; i++) {
50+
int shard_id = (*sharding_function)(i, shard_count);
51+
float reconstructed[index->quantizer->d];
52+
index->quantizer->reconstruct(i, reconstructed);
53+
sharded_centroids[shard_id].insert(
54+
sharded_centroids[shard_id].end(),
55+
&reconstructed[0],
56+
&reconstructed[index->quantizer->d]);
57+
}
58+
for (int i = 0; i < shard_count; i++) {
59+
sharded_indexes[i]->quantizer->add(
60+
sharded_centroids[i].size() / index->quantizer->d,
61+
sharded_centroids[i].data());
62+
}
63+
64+
std::vector<std::string> result;
65+
for (int i = 0; i < shard_count; i++) {
66+
char fname[256];
67+
std::string template_filename = (*filename_template_generator)();
68+
snprintf(fname, 256, template_filename.c_str(), i);
69+
result.emplace_back(fname);
70+
write_index(sharded_indexes[i], fname);
71+
}
72+
73+
for (int i = 0; i < shard_count; i++) {
74+
delete sharded_indexes[i];
75+
}
76+
77+
return result;
78+
}
79+
80+
} // namespace faiss

faiss/utils/sharding.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
// -*- c++ -*-
9+
10+
/*
11+
* Utilities related to sharding indexes.
12+
*/
13+
14+
#ifndef FAISS_SHARDING_H
15+
#define FAISS_SHARDING_H
16+
17+
#include <faiss/IndexIVF.h>
18+
19+
namespace faiss {
20+
21+
struct FilenameTemplateGenerator {
22+
virtual std::string operator()() = 0;
23+
virtual ~FilenameTemplateGenerator() {}
24+
};
25+
struct DefaultFilenameTemplateGenerator : FilenameTemplateGenerator {
26+
std::string operator()() override;
27+
};
28+
29+
struct ShardingFunction {
30+
virtual int64_t operator()(int64_t i, int64_t shard_count) = 0;
31+
virtual ~ShardingFunction() {}
32+
};
33+
struct DefaultShardingFunction : ShardingFunction {
34+
int64_t operator()(int64_t i, int64_t shard_count) override;
35+
};
36+
37+
/**
38+
* Shards an IVF index centroids by the given sharding function, and writes
39+
* the index to the path given by filename_generator. The centroids must already
40+
* be added to the index quantizer.
41+
*
42+
* @param index The IVF index containing centroids to
43+
* shard.
44+
* @param shard_count Number of shards.
45+
* @param filename_template_generator Function that generates a filename
46+
* template of the output indexes.
47+
* @param sharding_function The function to shard by. The default is
48+
* ith vector mod shard_count.
49+
* @return The list of output filenames.
50+
*/
51+
std::vector<std::string> shard_ivf_index_centroids(
52+
IndexIVF* index,
53+
int64_t shard_count = 20,
54+
std::shared_ptr<FilenameTemplateGenerator> filename_template_generator =
55+
nullptr,
56+
std::shared_ptr<ShardingFunction> sharding_function = nullptr);
57+
} // namespace faiss
58+
59+
#endif

tests/test_sharding.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import faiss
7+
import unittest
8+
import numpy as np
9+
import os
10+
11+
faiss.omp_set_num_threads(4)
12+
13+
14+
def custom_filename_template_generator_callback() -> str:
15+
randomInt = np.random.randint(0, 1000000000)
16+
return str("shard_%d.faissindex" + str(randomInt))
17+
18+
19+
def custom_sharding_callback(i, shard_count):
20+
del i
21+
del shard_count
22+
return 3
23+
24+
25+
class TestSharding(unittest.TestCase):
26+
"""
27+
Test the sharding.h from Python to ensure integration is working.
28+
No correctness checks are done here.
29+
"""
30+
def test_python_sharding(self):
31+
d = 32
32+
nb = 1000
33+
xb = np.random.rand(nb, d).astype('float32')
34+
nlist = 100
35+
quantizer = faiss.IndexFlatL2(d)
36+
index = faiss.IndexIVFFlat(quantizer, d, nlist)
37+
38+
index.quantizer.add(xb)
39+
40+
res = (
41+
faiss.shard_ivf_index_centroids(
42+
index,
43+
10,
44+
faiss.PyCallbackFilenameTemplateGenerator(
45+
custom_filename_template_generator_callback)
46+
)
47+
)
48+
for file in faiss.vector_to_array(res):
49+
os.remove(file)
50+
51+
res = (
52+
faiss.shard_ivf_index_centroids(
53+
index,
54+
10,
55+
faiss.PyCallbackFilenameTemplateGenerator(
56+
custom_filename_template_generator_callback),
57+
faiss.PyCallbackShardingFunction(
58+
custom_sharding_callback)
59+
)
60+
)
61+
for file in faiss.vector_to_array(res):
62+
os.remove(file)

0 commit comments

Comments
 (0)