Skip to content

Commit f348fe2

Browse files
Michael Norrisfacebook-github-bot
authored andcommitted
Add more unit tests for index_read and index_write
Summary: Adds missing coverage for index_write and index_read Differential Revision: D66846063
1 parent 8939f48 commit f348fe2

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

tests/test_io.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@
1212
import sys
1313
import pickle
1414
from multiprocessing.pool import ThreadPool
15+
from common_faiss_tests import get_dataset_2
1516

1617

18+
d = 32
19+
nt = 2000
20+
nb = 1000
21+
nq = 200
22+
1723
class TestIOVariants(unittest.TestCase):
1824

1925
def test_io_error(self):
@@ -338,6 +344,110 @@ def test_read_vector_transform(self):
338344
os.unlink(fname)
339345

340346

347+
class Test_IO_PQ(unittest.TestCase):
348+
"""
349+
test read and write PQ.
350+
"""
351+
def test_io_pq(self):
352+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
353+
index = faiss.IndexPQ(d, 4, 4)
354+
index.train(xt)
355+
356+
fd, fname = tempfile.mkstemp()
357+
os.close(fd)
358+
359+
try:
360+
faiss.write_ProductQuantizer(index.pq, fname)
361+
362+
read_pq = faiss.read_ProductQuantizer(fname)
363+
364+
self.assertEqual(index.pq.M, read_pq.M)
365+
self.assertEqual(index.pq.nbits, read_pq.nbits)
366+
self.assertEqual(index.pq.dsub, read_pq.dsub)
367+
self.assertEqual(index.pq.ksub, read_pq.ksub)
368+
np.testing.assert_array_equal(
369+
faiss.vector_to_array(index.pq.centroids),
370+
faiss.vector_to_array(read_pq.centroids)
371+
)
372+
373+
finally:
374+
if os.path.exists(fname):
375+
os.unlink(fname)
376+
377+
378+
class Test_IO_IndexLSH(unittest.TestCase):
379+
"""
380+
test read and write IndexLSH.
381+
"""
382+
def test_io_lsh(self):
383+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
384+
index_lsh = faiss.IndexLSH(d, 32, True, True)
385+
index_lsh.train(xt)
386+
index_lsh.add(xb)
387+
D, I = index_lsh.search(xq, 10)
388+
389+
fd, fname = tempfile.mkstemp()
390+
os.close(fd)
391+
392+
try:
393+
faiss.write_index(index_lsh, fname)
394+
395+
reader = faiss.BufferedIOReader(
396+
faiss.FileIOReader(fname), 1234)
397+
read_index_lsh = faiss.read_index(reader)
398+
399+
self.assertEqual(index_lsh.d, read_index_lsh.d)
400+
np.testing.assert_array_equal(
401+
faiss.vector_to_array(index_lsh.codes),
402+
faiss.vector_to_array(read_index_lsh.codes)
403+
)
404+
D_read, I_read = read_index_lsh.search(xq, 10)
405+
406+
np.testing.assert_array_equal(D, D_read)
407+
np.testing.assert_array_equal(I, I_read)
408+
409+
finally:
410+
if os.path.exists(fname):
411+
os.unlink(fname)
412+
413+
414+
class Test_IO_IndexIVFSpectralHash(unittest.TestCase):
415+
"""
416+
test read and write IndexIVFSpectralHash.
417+
"""
418+
def test_io_ivf_spectral_hash(self):
419+
nlist = 1000
420+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
421+
quantizer = faiss.IndexFlatL2(d)
422+
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0)
423+
index.train(xt)
424+
index.add(xb)
425+
D, I = index.search(xq, 10)
426+
427+
fd, fname = tempfile.mkstemp()
428+
os.close(fd)
429+
430+
try:
431+
faiss.write_index(index, fname)
432+
433+
reader = faiss.BufferedIOReader(
434+
faiss.FileIOReader(fname), 1234)
435+
436+
read_index = faiss.read_index(reader)
437+
438+
self.assertEqual(index.d, read_index.d)
439+
self.assertEqual(index.nbit, read_index.nbit)
440+
self.assertEqual(index.period, read_index.period)
441+
self.assertEqual(index.threshold_type, read_index.threshold_type)
442+
443+
D_read, I_read = read_index.search(xq, 10)
444+
np.testing.assert_array_equal(D, D_read)
445+
np.testing.assert_array_equal(I, I_read)
446+
447+
finally:
448+
if os.path.exists(fname):
449+
os.unlink(fname)
450+
341451
class TestIVFPQRead(unittest.TestCase):
342452
def test_reader(self):
343453
d, n = 32, 1000

0 commit comments

Comments
 (0)