@@ -84,9 +84,7 @@ def sp(x):
8484 b = btab [0 ]
8585 dis_new = self .compute_dis_quant (codes , LUTq , biasq , a , b )
8686
87- # print(a, b, dis_ref.sum())
8887 avg_realtive_error = np .abs (dis_new - dis_ref ).sum () / dis_ref .sum ()
89- # print('a=', a, 'avg_relative_error=', avg_realtive_error)
9088 self .assertLess (avg_realtive_error , 0.0005 )
9189
9290 def test_no_residual_ip (self ):
@@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
228226
229227 m3 = three_metrics (Da , Ia , Db , Ib )
230228
231-
232- # print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10)
233229 ref_results = {
234230 (True , 1 ): [0.985 , 1.0 , 9.872 ],
235231 (True , 0 ): [ 0.987 , 1.0 , 9.914 ],
@@ -261,36 +257,80 @@ class TestEquivPQ(unittest.TestCase):
261257
262258 def test_equiv_pq (self ):
263259 ds = datasets .SyntheticDataset (32 , 2000 , 200 , 4 )
260+ xq = ds .get_queries ()
264261
265262 index = faiss .index_factory (32 , "IVF1,PQ16x4np" )
266263 index .by_residual = False
267264 # force coarse quantizer
268265 index .quantizer .add (np .zeros ((1 , 32 ), dtype = 'float32' ))
269266 index .train (ds .get_train ())
270267 index .add (ds .get_database ())
271- Dref , Iref = index .search (ds . get_queries () , 4 )
268+ Dref , Iref = index .search (xq , 4 )
272269
273270 index_pq = faiss .index_factory (32 , "PQ16x4np" )
274271 index_pq .pq = index .pq
275272 index_pq .is_trained = True
276273 index_pq .codes = faiss . downcast_InvertedLists (
277274 index .invlists ).codes .at (0 )
278275 index_pq .ntotal = index .ntotal
279- Dnew , Inew = index_pq .search (ds . get_queries () , 4 )
276+ Dnew , Inew = index_pq .search (xq , 4 )
280277
281278 np .testing .assert_array_equal (Iref , Inew )
282279 np .testing .assert_array_equal (Dref , Dnew )
283280
284281 index_pq2 = faiss .IndexPQFastScan (index_pq )
285282 index_pq2 .implem = 12
286- Dref , Iref = index_pq2 .search (ds . get_queries () , 4 )
283+ Dref , Iref = index_pq2 .search (xq , 4 )
287284
288285 index2 = faiss .IndexIVFPQFastScan (index )
289286 index2 .implem = 12
290- Dnew , Inew = index2 .search (ds . get_queries () , 4 )
287+ Dnew , Inew = index2 .search (xq , 4 )
291288 np .testing .assert_array_equal (Iref , Inew )
292289 np .testing .assert_array_equal (Dref , Dnew )
293290
291+ # test encode and decode
292+
293+ np .testing .assert_array_equal (
294+ index_pq .sa_encode (xq ),
295+ index2 .sa_encode (xq )
296+ )
297+
298+ np .testing .assert_array_equal (
299+ index_pq .sa_decode (index_pq .sa_encode (xq )),
300+ index2 .sa_decode (index2 .sa_encode (xq ))
301+ )
302+
303+ np .testing .assert_array_equal (
304+ ((index_pq .sa_decode (index_pq .sa_encode (xq )) - xq ) ** 2 ).sum (1 ),
305+ ((index2 .sa_decode (index2 .sa_encode (xq )) - xq ) ** 2 ).sum (1 )
306+ )
307+
308+ def test_equiv_pq_encode_decode (self ):
309+ ds = datasets .SyntheticDataset (32 , 1000 , 200 , 10 )
310+ xq = ds .get_queries ()
311+
312+ index_ivfpq = faiss .index_factory (ds .d , "IVF10,PQ8x4np" )
313+ index_ivfpq .train (ds .get_train ())
314+
315+ index_ivfpqfs = faiss .IndexIVFPQFastScan (index_ivfpq )
316+
317+ np .testing .assert_array_equal (
318+ index_ivfpq .sa_encode (xq ),
319+ index_ivfpqfs .sa_encode (xq )
320+ )
321+
322+ np .testing .assert_array_equal (
323+ index_ivfpq .sa_decode (index_ivfpq .sa_encode (xq )),
324+ index_ivfpqfs .sa_decode (index_ivfpqfs .sa_encode (xq ))
325+ )
326+
327+ np .testing .assert_array_equal (
328+ ((index_ivfpq .sa_decode (index_ivfpq .sa_encode (xq )) - xq ) ** 2 )
329+ .sum (1 ),
330+ ((index_ivfpqfs .sa_decode (index_ivfpqfs .sa_encode (xq )) - xq ) ** 2 )
331+ .sum (1 )
332+ )
333+
294334
295335class TestIVFImplem12 (unittest .TestCase ):
296336
@@ -463,7 +503,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
463503 Dnew , Inew = index2 .search (ds .get_queries (), 10 )
464504
465505 m3 = three_metrics (Dref , Iref , Dnew , Inew )
466- # print((by_residual, metric, d), ":", m3)
467506 ref_m3_tab = {
468507 (True , 1 , 32 ): (0.995 , 1.0 , 9.91 ),
469508 (True , 0 , 32 ): (0.99 , 1.0 , 9.91 ),
@@ -554,7 +593,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
554593 recall_ref = (Iref == gt ).sum () / nq
555594 recall1 = (I1 == gt ).sum () / nq
556595
557- print (aq , st , by_residual , implem , metric_type , recall_ref , recall1 )
558596 assert abs (recall_ref - recall1 ) < 0.051
559597
560598 def xx_test_accuracy (self ):
@@ -599,7 +637,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
599637 recall_ref = (Iref == gt ).sum () / nq
600638 recall1 = (I1 == gt ).sum () / nq
601639
602- print (aq , st , by_residual , implem , recall_ref , recall1 )
603640 assert abs (recall_ref - recall1 ) < 0.05
604641
605642 def xx_test_rescale_accuracy (self ):
@@ -624,7 +661,6 @@ def subtest_from_ivfaq(self, implem):
624661 nq = Iref .shape [0 ]
625662 recall_ref = (Iref == gt ).sum () / nq
626663 recall1 = (I1 == gt ).sum () / nq
627- print (recall_ref , recall1 )
628664 assert abs (recall_ref - recall1 ) < 0.02
629665
630666 def test_from_ivfaq (self ):
@@ -763,7 +799,6 @@ def subtest_accuracy(self, paq):
763799 recall_ref = (Iref == gt ).sum () / nq
764800 recall1 = (I1 == gt ).sum () / nq
765801
766- print (paq , recall_ref , recall1 )
767802 assert abs (recall_ref - recall1 ) < 0.05
768803
769804 def test_accuracy_PLSQ (self ):
@@ -847,7 +882,6 @@ def do_test(self, metric=faiss.METRIC_L2):
847882 # find a reasonable radius
848883 D , I = index .search (ds .get_queries (), 10 )
849884 radius = np .median (D [:, - 1 ])
850- # print("radius=", radius)
851885 lims1 , D1 , I1 = index .range_search (ds .get_queries (), radius )
852886
853887 index2 = faiss .IndexIVFPQFastScan (index )
@@ -860,7 +894,6 @@ def do_test(self, metric=faiss.METRIC_L2):
860894 for i in range (ds .nq ):
861895 ref = set (I1 [lims1 [i ]: lims1 [i + 1 ]])
862896 new = set (I2 [lims2 [i ]: lims2 [i + 1 ]])
863- print (ref , new )
864897 nmiss += len (ref - new )
865898 nextra += len (new - ref )
866899
0 commit comments