Skip to content

Commit 96bc9c7

Browse files
Michael Norrisfacebook-github-bot
authored andcommitted
Add more unit tests for index_read and index_write (facebookresearch#4068)
Summary: Pull Request resolved: facebookresearch#4068 Adds missing coverage for index_write and index_read Reviewed By: asadoughi Differential Revision: D66846063 fbshipit-source-id: 68686318ecce64804425502a425a06f510fb5df8
1 parent 90e4c4d commit 96bc9c7

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

tests/test_io.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@
1212
import sys
1313
import pickle
1414
from multiprocessing.pool import ThreadPool
15+
from common_faiss_tests import get_dataset_2
1516

1617

18+
d = 32
19+
nt = 2000
20+
nb = 1000
21+
nq = 200
22+
1723
class TestIOVariants(unittest.TestCase):
1824

1925
def test_io_error(self):
@@ -338,6 +344,113 @@ def test_read_vector_transform(self):
338344
os.unlink(fname)
339345

340346

347+
class Test_IO_PQ(unittest.TestCase):
348+
"""
349+
test read and write PQ.
350+
"""
351+
def test_io_pq(self):
352+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
353+
index = faiss.IndexPQ(d, 4, 4)
354+
index.train(xt)
355+
356+
fd, fname = tempfile.mkstemp()
357+
os.close(fd)
358+
359+
try:
360+
faiss.write_ProductQuantizer(index.pq, fname)
361+
362+
read_pq = faiss.read_ProductQuantizer(fname)
363+
364+
self.assertEqual(index.pq.M, read_pq.M)
365+
self.assertEqual(index.pq.nbits, read_pq.nbits)
366+
self.assertEqual(index.pq.dsub, read_pq.dsub)
367+
self.assertEqual(index.pq.ksub, read_pq.ksub)
368+
np.testing.assert_array_equal(
369+
faiss.vector_to_array(index.pq.centroids),
370+
faiss.vector_to_array(read_pq.centroids)
371+
)
372+
373+
finally:
374+
if os.path.exists(fname):
375+
os.unlink(fname)
376+
377+
378+
class Test_IO_IndexLSH(unittest.TestCase):
379+
"""
380+
test read and write IndexLSH.
381+
"""
382+
def test_io_lsh(self):
383+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
384+
index_lsh = faiss.IndexLSH(d, 32, True, True)
385+
index_lsh.train(xt)
386+
index_lsh.add(xb)
387+
D, I = index_lsh.search(xq, 10)
388+
389+
fd, fname = tempfile.mkstemp()
390+
os.close(fd)
391+
392+
try:
393+
faiss.write_index(index_lsh, fname)
394+
395+
reader = faiss.BufferedIOReader(
396+
faiss.FileIOReader(fname), 1234)
397+
read_index_lsh = faiss.read_index(reader)
398+
# Delete reader to prevent [WinError 32] The process cannot
399+
# access the file because it is being used by another process
400+
del reader
401+
402+
self.assertEqual(index_lsh.d, read_index_lsh.d)
403+
np.testing.assert_array_equal(
404+
faiss.vector_to_array(index_lsh.codes),
405+
faiss.vector_to_array(read_index_lsh.codes)
406+
)
407+
D_read, I_read = read_index_lsh.search(xq, 10)
408+
409+
np.testing.assert_array_equal(D, D_read)
410+
np.testing.assert_array_equal(I, I_read)
411+
412+
finally:
413+
if os.path.exists(fname):
414+
os.unlink(fname)
415+
416+
417+
class Test_IO_IndexIVFSpectralHash(unittest.TestCase):
418+
"""
419+
test read and write IndexIVFSpectralHash.
420+
"""
421+
def test_io_ivf_spectral_hash(self):
422+
nlist = 1000
423+
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
424+
quantizer = faiss.IndexFlatL2(d)
425+
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0)
426+
index.train(xt)
427+
index.add(xb)
428+
D, I = index.search(xq, 10)
429+
430+
fd, fname = tempfile.mkstemp()
431+
os.close(fd)
432+
433+
try:
434+
faiss.write_index(index, fname)
435+
436+
reader = faiss.BufferedIOReader(
437+
faiss.FileIOReader(fname), 1234)
438+
read_index = faiss.read_index(reader)
439+
del reader
440+
441+
self.assertEqual(index.d, read_index.d)
442+
self.assertEqual(index.nbit, read_index.nbit)
443+
self.assertEqual(index.period, read_index.period)
444+
self.assertEqual(index.threshold_type, read_index.threshold_type)
445+
446+
D_read, I_read = read_index.search(xq, 10)
447+
np.testing.assert_array_equal(D, D_read)
448+
np.testing.assert_array_equal(I, I_read)
449+
450+
finally:
451+
if os.path.exists(fname):
452+
os.unlink(fname)
453+
341454
class TestIVFPQRead(unittest.TestCase):
342455
def test_reader(self):
343456
d, n = 32, 1000

0 commit comments

Comments
 (0)