1- #!/usr/bin/env -S grimaldi --kernel faiss_binary_local
1+ #!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
22# Copyright (c) Meta Platforms, Inc. and affiliates.
33#
44# This source code is licensed under the MIT license found in the
55# LICENSE file in the root directory of this source tree.
6-
76# fmt: off
87# flake8: noqa
98
109
1110""":md
12- # IndexPQ: separate codes from codebook
11+ # Serializing codes separately, with IndexLSH and IndexPQ
12+
13+ Let's say, for example, you have a few vector embeddings per user
14+ and want to shard a flat index by user so you can re-use the same LSH or PQ method
15+ for all users but store each user's codes independently.
1316
14- This notebook demonstrates how to separate serializing and deserializing the PQ codebook
15- (via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
16- where you have a few vector embeddings per user and want to shard the flat index by user you
17- can re-use the same PQ method for all users but store each user's codes independently.
1817
1918"""
2019
2423
2524""":py"""
2625d = 768
27- n = 10000
26+ n = 1_000
2827ids = np .arange (n ).astype ('int64' )
2928training_data = np .random .rand (n , d ).astype ('float32' )
30- M = d // 8
31- nbits = 8
3229
3330""":py"""
3431def read_ids_codes ():
@@ -50,9 +47,76 @@ def write_template_index(template_index):
5047def read_template_index_instance ():
5148 return faiss .read_index ("/tmp/template.index" )
5249
50+ """:md
51+ ## IndexLSH: separate codes
52+
53+ The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training. In fact, its compression method, a random projection matrix, is deterministic on construction based on a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
54+ """
55+
5356""":py"""
54- # at train time
57+ nbits = 1536
58+
59+ """:py"""
60+ # demonstrating encoding is deterministic
61+
62+ codes = []
63+ database_vector_float32 = np .random .rand (1 , d ).astype (np .float32 )
64+ for i in range (10 ):
65+ index = faiss .IndexIDMap2 (faiss .IndexLSH (d , nbits ))
66+ code = index .index .sa_encode (database_vector_float32 )
67+ codes .append (code )
68+
69+ for i in range (1 , 10 ):
70+ assert np .array_equal (codes [0 ], codes [i ])
71+
72+ """:py"""
73+ # new database vector
74+
75+ ids , codes = read_ids_codes ()
76+ database_vector_id , database_vector_float32 = max (ids ) + 1 if ids is not None else 1 , np .random .rand (1 , d ).astype (np .float32 )
77+ index = faiss .IndexIDMap2 (faiss .IndexLSH (d , nbits ))
78+
79+ code = index .index .sa_encode (database_vector_float32 )
5580
81+ if ids is not None and codes is not None :
82+ ids = np .concatenate ((ids , [database_vector_id ]))
83+ codes = np .vstack ((codes , code ))
84+ else :
85+ ids = np .array ([database_vector_id ])
86+ codes = np .array ([code ])
87+
88+ write_ids_codes (ids , codes )
89+
90+ """:py '2840581589434841'"""
91+ # then at query time
92+
93+ query_vector_float32 = np .random .rand (1 , d ).astype (np .float32 )
94+ index = faiss .IndexIDMap2 (faiss .IndexLSH (d , nbits ))
95+ ids , codes = read_ids_codes ()
96+
97+ index .add_sa_codes (codes , ids )
98+
99+ index .search (query_vector_float32 , k = 5 )
100+
101+ """:py"""
102+ !rm / tmp / ids .npy / tmp / codes .npy
103+
104+ """:md
105+ ## IndexPQ: separate codes from codebook
106+
107+ The second half of this notebook demonstrates how to separate serializing and deserializing the PQ codebook
108+ (via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
109+ where you have a few vector embeddings per user and want to shard the flat index by user, you
110+ can re-use the same PQ method for all users but store each user's codes independently.
111+
112+ """
113+
114+ """:py"""
115+ M = d // 8
116+ nbits = 8
117+
118+ """:py"""
119+ # at train time
56120template_index = faiss .index_factory (d , f"IDMap2,PQ{ M } x{ nbits } " )
57121template_index .train (training_data )
58122write_template_index (template_index )
@@ -61,8 +125,8 @@ def read_template_index_instance():
61125# New database vector
62126
63127index = read_template_index_instance ()
64- database_vector_id , database_vector_float32 = np .random .randint (10000 ), np .random .rand (1 , d ).astype (np .float32 )
65128ids , codes = read_ids_codes ()
129+ database_vector_id , database_vector_float32 = max (ids ) + 1 if ids is not None else 1 , np .random .rand (1 , d ).astype (np .float32 )
66130
67131code = index .index .sa_encode (database_vector_float32 )
68132
@@ -75,7 +139,7 @@ def read_template_index_instance():
75139
76140write_ids_codes (ids , codes )
77141
78- """:py '331546060044009 '"""
142+ """:py '1858280061369209 '"""
79143# then at query time
80144query_vector_float32 = np .random .rand (1 , d ).astype (np .float32 )
81145id_wrapper_index = read_template_index_instance ()
@@ -87,3 +151,153 @@ def read_template_index_instance():
87151
88152""":py"""
89153!rm / tmp / ids .npy / tmp / codes .npy / tmp / template .index
154+
155+ """:md
156+ ## Comparing these methods
157+
158+ - methods: Flat, LSH, PQ
159+ - vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
160+ - measure: recall@1
161+
162+ We don't measure latency as the number of vectors per user shard is insignificant.
163+
164+ """
165+
166+ """:py '2898032417027201'"""
167+ n , d
168+
169+ """:py"""
170+ database_vector_ids , database_vector_float32s = np .arange (n ), np .random .rand (n , d ).astype (np .float32 )
171+ query_vector_float32s = np .random .rand (n , d ).astype (np .float32 )
172+
173+ """:py"""
174+ index = faiss .index_factory (d , "IDMap2,Flat" )
175+ index .add_with_ids (database_vector_float32s , database_vector_ids )
176+ _ , ground_truth_result_ids = index .search (query_vector_float32s , k = 1 )
177+
178+ """:py '857475336204238'"""
179+ from dataclasses import dataclass
180+
181+ pq_m_nbits = (
182+ # 96 bytes
183+ (96 , 8 ),
184+ (192 , 4 ),
185+ # 192 bytes
186+ (192 , 8 ),
187+ (384 , 4 ),
188+ # 384 bytes
189+ (384 , 8 ),
190+ (768 , 4 ),
191+ )
192+ lsh_nbits = (768 , 1536 , 3072 , 6144 , 12288 , 24576 )
193+
194+
195+ @dataclass
196+ class Record :
197+ type_ : str
198+ index : faiss .Index
199+ args : tuple
200+ recall : float
201+
202+
203+ results = []
204+
205+ for m , nbits in pq_m_nbits :
206+ print ("pq" , m , nbits )
207+ index = faiss .index_factory (d , f"IDMap2,PQ{ m } x{ nbits } " )
208+ index .train (training_data )
209+ index .add_with_ids (database_vector_float32s , database_vector_ids )
210+ _ , result_ids = index .search (query_vector_float32s , k = 1 )
211+ recall = sum (result_ids == ground_truth_result_ids )
212+ results .append (Record ("pq" , index , (m , nbits ), recall ))
213+
214+ for nbits in lsh_nbits :
215+ print ("lsh" , nbits )
216+ index = faiss .IndexIDMap2 (faiss .IndexLSH (d , nbits ))
217+ index .add_with_ids (database_vector_float32s , database_vector_ids )
218+ _ , result_ids = index .search (query_vector_float32s , k = 1 )
219+ recall = sum (result_ids == ground_truth_result_ids )
220+ results .append (Record ("lsh" , index , (nbits ,), recall ))
221+
222+ """:py '556918346720794'"""
223+ import matplotlib .pyplot as plt
224+ import numpy as np
225+
226+ def create_grouped_bar_chart (x_values , y_values_list , labels_list , xlabel , ylabel , title ):
227+ num_bars_per_group = len (x_values )
228+
229+ plt .figure (figsize = (12 , 6 ))
230+
231+ for x , y_values , labels in zip (x_values , y_values_list , labels_list ):
232+ num_bars = len (y_values )
233+ bar_width = 0.08 * x
234+ bar_positions = np .arange (num_bars ) * bar_width - (num_bars - 1 ) * bar_width / 2 + x
235+
236+ bars = plt .bar (bar_positions , y_values , width = bar_width )
237+
238+ for bar , label in zip (bars , labels ):
239+ height = bar .get_height ()
240+ plt .annotate (
241+ label ,
242+ xy = (bar .get_x () + bar .get_width () / 2 , height ),
243+ xytext = (0 , 3 ),
244+ textcoords = "offset points" ,
245+ ha = 'center' , va = 'bottom'
246+ )
247+
248+ plt .xscale ('log' )
249+ plt .xlabel (xlabel )
250+ plt .ylabel (ylabel )
251+ plt .title (title )
252+ plt .xticks (x_values , labels = [str (x ) for x in x_values ])
253+ plt .tight_layout ()
254+ plt .show ()
255+
256+ # # Example usage:
257+ # x_values = [1, 2, 4, 8, 16, 32]
258+ # y_values_list = [
259+ # [2.5, 3.6, 1.8],
260+ # [3.0, 2.8],
261+ # [2.5, 3.5, 4.0, 1.0],
262+ # [4.2],
263+ # [3.0, 5.5, 2.2],
264+ # [6.0, 4.5]
265+ # ]
266+ # labels_list = [
267+ # ['A1', 'B1', 'C1'],
268+ # ['A2', 'B2'],
269+ # ['A3', 'B3', 'C3', 'D3'],
270+ # ['A4'],
271+ # ['A5', 'B5', 'C5'],
272+ # ['A6', 'B6']
273+ # ]
274+
275+ # create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")
276+
277+ """:py '1630106834206134'"""
278+ # x-axis: compression ratio
279+ # y-axis: recall@1
280+
281+ from collections import defaultdict
282+
283+ x = defaultdict (list )
284+ x [1 ].append (("flat" , 1.00 ))
285+ for r in results :
286+ y_value = r .recall [0 ] / n
287+ x_value = int (d * 4 / r .index .sa_code_size ())
288+ label = None
289+ if r .type_ == "pq" :
290+ label = f"PQ{ r .args [0 ]} x{ r .args [1 ]} "
291+ if r .type_ == "lsh" :
292+ label = f"LSH{ r .args [0 ]} "
293+ x [x_value ].append ((label , y_value ))
294+
295+ x_values = sorted (list (x .keys ()))
296+ create_grouped_bar_chart (
297+ x_values ,
298+ [[e [1 ] for e in x [x_value ]] for x_value in x_values ],
299+ [[e [0 ] for e in x [x_value ]] for x_value in x_values ],
300+ "compression ratio" ,
301+ "recall@1 q=1,000 queries" ,
302+ "recall@1 for a database of n=1,000 d=768 vectors" ,
303+ )
0 commit comments