Skip to content

Commit 02a6c41

Browse files
mdouzefacebook-github-bot
authored andcommitted
QINCo implementation in CPU Faiss (#3608)
Summary: Pull Request resolved: #3608 This is a straightforward implementation of QINCo in CPU Faiss, with encoding and decoding capabilities (not training). For this, we translate a simplified version of some torch classes: - tensors, restricted to 2D and int32 + float32 - Linear and Embedding layers. The QINCoStep and QINCo classes can then simply be defined as C++ objects that are copy-constructible. Some plumbing is required in the wrapping layers to support the integration. PyTorch tensors are converted to numpy for getting / setting them in C++. Differential Revision: D59132952
1 parent f821704 commit 02a6c41

8 files changed

Lines changed: 1021 additions & 1 deletion

File tree

demos/demo_qinco.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
This demonstrates how to reproduce the QINCo paper results using the Faiss
8+
QINCo implementation. The code loads the reference model because training is not
9+
implemented in Faiss.
10+
11+
Prepare the data with
12+
13+
cd /tmp
14+
15+
# get the reference qinco code
16+
git clone https://github.com/facebookresearch/Qinco.git
17+
18+
# get the data
19+
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
20+
21+
# get the model
22+
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
23+
24+
"""
25+
26+
import numpy as np
27+
from faiss.contrib.vecs_io import bvecs_mmap
28+
import sys
29+
import time
30+
import torch
31+
import faiss
32+
33+
# make sure pickle deserialization will work
34+
sys.path.append("/tmp/Qinco")
35+
import model_qinco
36+
37+
with torch.no_grad():
    # SECURITY: the checkpoint is a fully pickled model object, so loading it
    # executes arbitrary code from the file. Only load checkpoints obtained
    # from a trusted source. weights_only=False is passed explicitly because
    # recent PyTorch releases default to weights_only=True, which refuses to
    # unpickle a whole model.
    qinco = torch.load("/tmp/bigann_8x8_L2.pt", weights_only=False)
    qinco.eval()
    # print(qinco)

    # Flip to False to let both libraries use their default thread counts
    # (the timings quoted below cover both settings).
    single_thread = True
    if single_thread:
        torch.set_num_threads(1)
        faiss.omp_set_num_threads(1)

    # Work on the first 1000 database vectors, rescaled the way the
    # reference model expects its input.
    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    # Reference PyTorch implementation (the timer covers encode + decode).
    t0 = time.time()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    err = ((x_decoded - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    # Faiss implementation, built from the PyTorch model
    # (the timer covers encode + decode).
    qinco2 = faiss.QINCo(qinco)
    t0 = time.time()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # Exact comparisons don't work because there are outlier encodings:
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)
    # so instead require that fewer than 1% of the code entries, and fewer
    # than 1% of the decoded vectors, differ between the two implementations.
    code_mismatch_rate = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    assert code_mismatch_rate < 0.01
    n_decode_outliers = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert n_decode_outliers / len(x_base) < 0.01

    err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14213.551

faiss/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ set(FAISS_SRC
8181
invlists/InvertedLists.cpp
8282
invlists/InvertedListsIOHook.cpp
8383
utils/Heap.cpp
84+
utils/NeuralNet.cpp
8485
utils/WorkerThread.cpp
8586
utils/distances.cpp
8687
utils/distances_simd.cpp

faiss/python/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444
class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
4545
class_wrappers.handle_CodeSet(CodeSet)
4646

47+
class_wrappers.handle_Tensor2D(Tensor2D)
48+
class_wrappers.handle_Tensor2D(Int32Tensor2D)
49+
class_wrappers.handle_Embedding(Embedding)
50+
class_wrappers.handle_Linear(Linear)
51+
class_wrappers.handle_QINCo(QINCo)
52+
class_wrappers.handle_QINCoStep(QINCoStep)
53+
54+
4755
this_module = sys.modules[__name__]
4856

4957
# handle sub-classes

faiss/python/class_wrappers.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,3 +1247,133 @@ def replacement_insert(self, codes, inserted=None):
12471247
return inserted
12481248

12491249
replace_method(the_class, 'insert', replacement_insert)
1250+
1251+
######################################################
1252+
# Syntactic sugar for NeuralNet classes
1253+
######################################################
1254+
1255+
def handle_Tensor2D(the_class):
    """Patch a wrapped 2D tensor class with array-friendly helpers.

    The native constructor is kept as ``the_class.original_init``. The
    class then gets a constructor that also accepts a single 2D array-like
    object, and a ``numpy()`` method that returns the contents as an array.
    """
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        """Construct from native arguments or from one 2D array-like.

        A single argument that is not itself an instance of the wrapped
        class must expose a two-element ``.shape`` and be convertible with
        ``np.ascontiguousarray``. Every other argument list is forwarded
        unchanged to the native constructor.
        """
        # Same dispatch rule as the other NeuralNet handlers: an instance of
        # the wrapped class is handed to the native constructor instead of
        # being unpacked as an array (its native `shape` field is not a
        # Python tuple, see numpy() below).
        if len(args) != 1 or args[0].__class__ == the_class:
            self.original_init(*args)
            return
        array, = args
        n, d = array.shape
        self.original_init(n, d)
        # The data is flattened to 1D before being copied into self.v.
        faiss.copy_array_to_vector(
            np.ascontiguousarray(array).ravel(), self.v)

    def numpy(self):
        """Return the tensor contents as a 2D numpy array."""
        # self.shape is a native field: its two entries are fetched with a
        # raw memcpy into an int64 buffer.
        # NOTE(review): this assumes each entry is 8 bytes wide -- confirm.
        shape = np.zeros(2, dtype=np.int64)
        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])

    the_class.__init__ = replacement_init
    the_class.numpy = numpy
1276+
1277+
def handle_Embedding(the_class):
    """Let the wrapped Embedding class be built from a torch embedding.

    The native constructor stays reachable as ``the_class.original_init``;
    a ``from_torch`` method is added to copy the weights over.
    """
    the_class.original_init = the_class.__init__

    def from_torch(self, emb):
        """ copy weights from torch.Embedding """
        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
        flat_weights = np.ascontiguousarray(emb.weight.data).ravel()
        faiss.copy_array_to_vector(flat_weights, self.weight)

    def replacement_init(self, *args):
        """Build from native arguments or from a single torch embedding."""
        is_foreign_object = len(args) == 1 and args[0].__class__ != the_class
        if is_foreign_object:
            # assume it's a torch.Embedding
            torch_emb = args[0]
            self.original_init(
                torch_emb.num_embeddings, torch_emb.embedding_dim)
            self.from_torch(torch_emb)
        else:
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
1297+
1298+
1299+
def handle_Linear(the_class):
    """Let the wrapped Linear class be built from a torch linear layer.

    The native constructor is kept as ``the_class.original_init``. The
    replacement constructor additionally accepts one torch-style layer
    (anything exposing ``in_features``, ``out_features``, ``weight`` and
    ``bias``), and ``from_torch`` copies its parameters.
    """
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        """Build from native arguments or from a single torch layer."""
        if len(args) != 1 or args[0].__class__ == the_class:
            self.original_init(*args)
            return
        # assume it's a torch.Linear
        linear = args[0]
        bias = linear.bias is not None
        self.original_init(linear.in_features, linear.out_features, bias)
        self.from_torch(linear)

    def from_torch(self, linear):
        """ copy weights from torch.Linear """
        assert linear.weight.shape == (self.out_features, self.in_features)
        # Convert through np.ascontiguousarray, as handle_Embedding does, so
        # that both parameters reach the copy as contiguous arrays.
        faiss.copy_array_to_vector(
            np.ascontiguousarray(linear.weight.data).ravel(), self.weight)
        if linear.bias is not None:
            assert linear.bias.shape == (self.out_features,)
            faiss.copy_array_to_vector(
                np.ascontiguousarray(linear.bias.data), self.bias)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
1322+
1323+
######################################################
1324+
# Syntactic sugar for QINCo and QINCoStep
1325+
######################################################
1326+
1327+
def handle_QINCoStep(the_class):
    """Let the wrapped QINCoStep class be built from its torch counterpart.

    The native constructor stays reachable as ``the_class.original_init``;
    ``from_torch`` copies the codebook, the concatenation layer and both
    linear layers of every residual block.
    """
    the_class.original_init = the_class.__init__

    def from_torch(self, step):
        """ copy weights from torch.QINCoStep """
        assert (step.d, step.K, step.L, step.h) == (self.d, self.K, self.L, self.h)
        self.codebook.from_torch(step.codebook)
        self.MLPconcat.from_torch(step.MLPconcat)

        for block_no in range(step.L):
            torch_block = step.residual_blocks[block_no]
            native_block = self.get_residual_block(block_no)
            # entries 0 and 2 of the torch block are the layers copied into
            # linear1 and linear2; entry 1 is not copied
            native_block.linear1.from_torch(torch_block[0])
            native_block.linear2.from_torch(torch_block[2])

    def replacement_init(self, *args):
        """Build from native arguments or from a single torch step."""
        is_foreign_object = len(args) == 1 and args[0].__class__ != the_class
        if is_foreign_object:
            torch_step = args[0]
            # assume it's a Torch QINCoStep
            self.original_init(
                torch_step.d, torch_step.K, torch_step.L, torch_step.h)
            self.from_torch(torch_step)
        else:
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
1353+
1354+
1355+
def handle_QINCo(the_class):
    """Let the wrapped QINCo class be built from its torch counterpart.

    The native constructor stays reachable as ``the_class.original_init``;
    ``from_torch`` copies the first codebook and then every step.
    """
    the_class.original_init = the_class.__init__

    def from_torch(self, qinco):
        """ copy weights from torch.QINCo """
        assert (
            (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h) ==
            (self.d, self.K, self.L, self.M, self.h)
        )
        self.codebook0.from_torch(qinco.codebook0)
        # codebook0 is copied above; the remaining M - 1 quantization
        # stages are the steps
        n_steps = qinco.M - 1
        for step_no in range(n_steps):
            native_step = self.get_step(step_no)
            native_step.from_torch(qinco.steps[step_no])

    def replacement_init(self, *args):
        """Build from native arguments or from a single torch model."""
        is_foreign_object = len(args) == 1 and args[0].__class__ != the_class
        if is_foreign_object:
            # assume it's a Torch QINCo
            torch_model = args[0]
            self.original_init(
                torch_model.d, torch_model.K, torch_model.L,
                torch_model.M, torch_model.h)
            self.from_torch(torch_model)
        else:
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch

faiss/python/swigfaiss.swig

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ typedef uint64_t size_t;
145145
#include <faiss/impl/LocalSearchQuantizer.h>
146146
#include <faiss/impl/ProductAdditiveQuantizer.h>
147147
#include <faiss/impl/CodePacker.h>
148+
#include <faiss/utils/NeuralNet.h>
148149

149150
#include <faiss/invlists/BlockInvertedLists.h>
150151

@@ -257,7 +258,6 @@ namespace std {
257258
%template(ClusteringIterationStatsVector) std::vector<faiss::ClusteringIterationStats>;
258259
%template(ParameterRangeVector) std::vector<faiss::ParameterRange>;
259260

260-
261261
#ifndef SWIGWIN
262262
%template(OnDiskOneListVector) std::vector<faiss::OnDiskOneList>;
263263
#endif // !SWIGWIN
@@ -530,6 +530,9 @@ struct faiss::simd16uint16 {};
530530

531531
%include <faiss/IndexRowwiseMinMax.h>
532532

533+
%include <faiss/utils/NeuralNet.h>
534+
%template(Tensor2D) faiss::nn::Tensor2DTemplate<float>;
535+
%template(Int32Tensor2D) faiss::nn::Tensor2DTemplate<int32_t>;
533536

534537
%ignore faiss::BufferList::Buffer;
535538
%ignore faiss::RangeSearchPartialResult::QueryResult;

0 commit comments

Comments
 (0)