Skip to content

Commit 3150e89

Browse files
asadoughi and facebook-github-bot
authored and committed
Added IndexLSH to the demo (facebookresearch#4009)
Summary: Demonstrate that IndexLSH does not need training or codebook serialization. Pull Request resolved: facebookresearch#4009. Reviewed By: junjieqi. Differential Revision: D65274645. Pulled By: asadoughi. fbshipit-source-id: c9af463757edbd07cc07b1cf607b88373fa334c4
1 parent 6620e26 commit 3150e89

1 file changed

Lines changed: 227 additions & 13 deletions

File tree

demos/index_pq_flat_separate_codes_from_codebook.py

Lines changed: 227 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
1-
#!/usr/bin/env -S grimaldi --kernel faiss_binary_local
1+
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
#
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
# fmt: off
87
# flake8: noqa
98

109

1110
""":md
12-
# IndexPQ: separate codes from codebook
11+
# Serializing codes separately, with IndexLSH and IndexPQ
12+
13+
Let's say, for example, you have a few vector embeddings per user
14+
and want to shard a flat index by user so you can re-use the same LSH or PQ method
15+
for all users but store each user's codes independently.
1316
14-
This notebook demonstrates how to separate serializing and deserializing the PQ codebook
15-
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
16-
where you have a few vector embeddings per user and want to shard the flat index by user you
17-
can re-use the same PQ method for all users but store each user's codes independently.
1817
1918
"""
2019

@@ -24,11 +23,9 @@
2423

2524
""":py"""
2625
d = 768
27-
n = 10000
26+
n = 1_000
2827
ids = np.arange(n).astype('int64')
2928
training_data = np.random.rand(n, d).astype('float32')
30-
M = d//8
31-
nbits = 8
3229

3330
""":py"""
3431
def read_ids_codes():
@@ -50,9 +47,76 @@ def write_template_index(template_index):
5047
def read_template_index_instance():
5148
return faiss.read_index("/tmp/template.index")
5249

50+
""":md
51+
## IndexLSH: separate codes
52+
53+
The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training. In fact, its compression method, a random projection matrix, is deterministic on construction, based on a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
54+
"""
55+
5356
""":py"""
54-
# at train time
57+
nbits = 1536
58+
59+
""":py"""
60+
# demonstrating encoding is deterministic
61+
62+
codes = []
63+
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
64+
for i in range(10):
65+
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
66+
code = index.index.sa_encode(database_vector_float32)
67+
codes.append(code)
68+
69+
for i in range(1, 10):
70+
assert np.array_equal(codes[0], codes[i])
71+
72+
""":py"""
73+
# new database vector
74+
75+
ids, codes = read_ids_codes()
76+
database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
77+
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
78+
79+
code = index.index.sa_encode(database_vector_float32)
5580

81+
if ids is not None and codes is not None:
82+
ids = np.concatenate((ids, [database_vector_id]))
83+
codes = np.vstack((codes, code))
84+
else:
85+
ids = np.array([database_vector_id])
86+
codes = np.array([code])
87+
88+
write_ids_codes(ids, codes)
89+
90+
""":py '2840581589434841'"""
91+
# then at query time
92+
93+
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
94+
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
95+
ids, codes = read_ids_codes()
96+
97+
index.add_sa_codes(codes, ids)
98+
99+
index.search(query_vector_float32, k=5)
100+
101+
""":py"""
102+
!rm /tmp/ids.npy /tmp/codes.npy
103+
104+
""":md
105+
## IndexPQ: separate codes from codebook
106+
107+
The second half of this notebook demonstrates how to separate serializing and deserializing the PQ codebook
108+
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
109+
where you have a few vector embeddings per user and want to shard the flat index by user you
110+
can re-use the same PQ method for all users but store each user's codes independently.
111+
112+
"""
113+
114+
""":py"""
115+
M = d//8
116+
nbits = 8
117+
118+
""":py"""
119+
# at train time
56120
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
57121
template_index.train(training_data)
58122
write_template_index(template_index)
@@ -61,8 +125,8 @@ def read_template_index_instance():
61125
# New database vector
62126

63127
index = read_template_index_instance()
64-
database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32)
65128
ids, codes = read_ids_codes()
129+
database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
66130

67131
code = index.index.sa_encode(database_vector_float32)
68132

@@ -75,7 +139,7 @@ def read_template_index_instance():
75139

76140
write_ids_codes(ids, codes)
77141

78-
""":py '331546060044009'"""
142+
""":py '1858280061369209'"""
79143
# then at query time
80144
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
81145
id_wrapper_index = read_template_index_instance()
@@ -87,3 +151,153 @@ def read_template_index_instance():
87151

88152
""":py"""
89153
!rm /tmp/ids.npy /tmp/codes.npy /tmp/template.index
154+
155+
""":md
156+
## Comparing these methods
157+
158+
- methods: Flat, LSH, PQ
159+
- vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
160+
- measure: recall@1
161+
162+
We don't measure latency as the number of vectors per user shard is insignificant.
163+
164+
"""
165+
166+
""":py '2898032417027201'"""
167+
n, d
168+
169+
""":py"""
170+
database_vector_ids, database_vector_float32s = np.arange(n), np.random.rand(n, d).astype(np.float32)
171+
query_vector_float32s = np.random.rand(n, d).astype(np.float32)
172+
173+
""":py"""
174+
index = faiss.index_factory(d, "IDMap2,Flat")
175+
index.add_with_ids(database_vector_float32s, database_vector_ids)
176+
_, ground_truth_result_ids= index.search(query_vector_float32s, k=1)
177+
178+
""":py '857475336204238'"""
179+
from dataclasses import dataclass
180+
181+
pq_m_nbits = (
182+
# 96 bytes
183+
(96, 8),
184+
(192, 4),
185+
# 192 bytes
186+
(192, 8),
187+
(384, 4),
188+
# 384 bytes
189+
(384, 8),
190+
(768, 4),
191+
)
192+
lsh_nbits = (768, 1536, 3072, 6144, 12288, 24576)
193+
194+
195+
@dataclass
196+
class Record:
197+
type_: str
198+
index: faiss.Index
199+
args: tuple
200+
recall: float
201+
202+
203+
results = []
204+
205+
for m, nbits in pq_m_nbits:
206+
print("pq", m, nbits)
207+
index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
208+
index.train(training_data)
209+
index.add_with_ids(database_vector_float32s, database_vector_ids)
210+
_, result_ids = index.search(query_vector_float32s, k=1)
211+
recall = sum(result_ids == ground_truth_result_ids)
212+
results.append(Record("pq", index, (m, nbits), recall))
213+
214+
for nbits in lsh_nbits:
215+
print("lsh", nbits)
216+
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
217+
index.add_with_ids(database_vector_float32s, database_vector_ids)
218+
_, result_ids = index.search(query_vector_float32s, k=1)
219+
recall = sum(result_ids == ground_truth_result_ids)
220+
results.append(Record("lsh", index, (nbits,), recall))
221+
222+
""":py '556918346720794'"""
223+
import matplotlib.pyplot as plt
224+
import numpy as np
225+
226+
def create_grouped_bar_chart(x_values, y_values_list, labels_list, xlabel, ylabel, title):
    """Draw groups of labeled bars on a log-scaled x axis.

    Each entry of ``x_values`` is the x position of one group; the matching
    entries of ``y_values_list`` and ``labels_list`` give that group's bar
    heights and per-bar text annotations. Bars within a group are centered
    on the group's x position.

    Fix vs. original: removed the unused local ``num_bars_per_group``
    (computed from ``len(x_values)`` but never read).
    """
    plt.figure(figsize=(12, 6))

    for x, heights, bar_labels in zip(x_values, y_values_list, labels_list):
        num_bars = len(heights)
        # Width is proportional to x so bars keep the same apparent size
        # once the x axis is switched to a log scale below.
        bar_width = 0.08 * x
        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x

        bars = plt.bar(bar_positions, heights, width=bar_width)

        # Annotate each bar with its label, nudged 3 points above the bar top.
        for bar, label in zip(bars, bar_labels):
            height = bar.get_height()
            plt.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom'
            )

    plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x_values, labels=[str(x) for x in x_values])
    plt.tight_layout()
    plt.show()
255+
256+
# # Example usage:
257+
# x_values = [1, 2, 4, 8, 16, 32]
258+
# y_values_list = [
259+
# [2.5, 3.6, 1.8],
260+
# [3.0, 2.8],
261+
# [2.5, 3.5, 4.0, 1.0],
262+
# [4.2],
263+
# [3.0, 5.5, 2.2],
264+
# [6.0, 4.5]
265+
# ]
266+
# labels_list = [
267+
# ['A1', 'B1', 'C1'],
268+
# ['A2', 'B2'],
269+
# ['A3', 'B3', 'C3', 'D3'],
270+
# ['A4'],
271+
# ['A5', 'B5', 'C5'],
272+
# ['A6', 'B6']
273+
# ]
274+
275+
# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")
276+
277+
""":py '1630106834206134'"""
278+
# x-axis: compression ratio
279+
# y-axis: recall@1
280+
281+
from collections import defaultdict
282+
283+
x = defaultdict(list)
284+
x[1].append(("flat", 1.00))
285+
for r in results:
286+
y_value = r.recall[0] / n
287+
x_value = int(d * 4 / r.index.sa_code_size())
288+
label = None
289+
if r.type_ == "pq":
290+
label = f"PQ{r.args[0]}x{r.args[1]}"
291+
if r.type_ == "lsh":
292+
label = f"LSH{r.args[0]}"
293+
x[x_value].append((label, y_value))
294+
295+
x_values = sorted(list(x.keys()))
296+
create_grouped_bar_chart(
297+
x_values,
298+
[[e[1] for e in x[x_value]] for x_value in x_values],
299+
[[e[0] for e in x[x_value]] for x_value in x_values],
300+
"compression ratio",
301+
"recall@1 q=1,000 queries",
302+
"recall@1 for a database of n=1,000 d=768 vectors",
303+
)

0 commit comments

Comments
 (0)