|
12 | 12 | import sys |
13 | 13 | import pickle |
14 | 14 | from multiprocessing.pool import ThreadPool |
| 15 | +from common_faiss_tests import get_dataset_2 |
15 | 16 |
|
16 | 17 |
|
| 18 | +d = 32 |
| 19 | +nt = 2000 |
| 20 | +nb = 1000 |
| 21 | +nq = 200 |
| 22 | + |
17 | 23 | class TestIOVariants(unittest.TestCase): |
18 | 24 |
|
19 | 25 | def test_io_error(self): |
@@ -338,6 +344,113 @@ def test_read_vector_transform(self): |
338 | 344 | os.unlink(fname) |
339 | 345 |
|
340 | 346 |
|
| 347 | +class Test_IO_PQ(unittest.TestCase): |
| 348 | + """ |
| 349 | + test read and write PQ. |
| 350 | + """ |
| 351 | + def test_io_pq(self): |
| 352 | + xt, xb, xq = get_dataset_2(d, nt, nb, nq) |
| 353 | + index = faiss.IndexPQ(d, 4, 4) |
| 354 | + index.train(xt) |
| 355 | + |
| 356 | + fd, fname = tempfile.mkstemp() |
| 357 | + os.close(fd) |
| 358 | + |
| 359 | + try: |
| 360 | + faiss.write_ProductQuantizer(index.pq, fname) |
| 361 | + |
| 362 | + read_pq = faiss.read_ProductQuantizer(fname) |
| 363 | + |
| 364 | + self.assertEqual(index.pq.M, read_pq.M) |
| 365 | + self.assertEqual(index.pq.nbits, read_pq.nbits) |
| 366 | + self.assertEqual(index.pq.dsub, read_pq.dsub) |
| 367 | + self.assertEqual(index.pq.ksub, read_pq.ksub) |
| 368 | + np.testing.assert_array_equal( |
| 369 | + faiss.vector_to_array(index.pq.centroids), |
| 370 | + faiss.vector_to_array(read_pq.centroids) |
| 371 | + ) |
| 372 | + |
| 373 | + finally: |
| 374 | + if os.path.exists(fname): |
| 375 | + os.unlink(fname) |
| 376 | + |
| 377 | + |
| 378 | +class Test_IO_IndexLSH(unittest.TestCase): |
| 379 | + """ |
| 380 | + test read and write IndexLSH. |
| 381 | + """ |
| 382 | + def test_io_lsh(self): |
| 383 | + xt, xb, xq = get_dataset_2(d, nt, nb, nq) |
| 384 | + index_lsh = faiss.IndexLSH(d, 32, True, True) |
| 385 | + index_lsh.train(xt) |
| 386 | + index_lsh.add(xb) |
| 387 | + D, I = index_lsh.search(xq, 10) |
| 388 | + |
| 389 | + fd, fname = tempfile.mkstemp() |
| 390 | + os.close(fd) |
| 391 | + |
| 392 | + try: |
| 393 | + faiss.write_index(index_lsh, fname) |
| 394 | + |
| 395 | + reader = faiss.BufferedIOReader( |
| 396 | + faiss.FileIOReader(fname), 1234) |
| 397 | + read_index_lsh = faiss.read_index(reader) |
| 398 | + # Delete reader to prevent [WinError 32] The process cannot |
| 399 | + # access the file because it is being used by another process |
| 400 | + del reader |
| 401 | + |
| 402 | + self.assertEqual(index_lsh.d, read_index_lsh.d) |
| 403 | + np.testing.assert_array_equal( |
| 404 | + faiss.vector_to_array(index_lsh.codes), |
| 405 | + faiss.vector_to_array(read_index_lsh.codes) |
| 406 | + ) |
| 407 | + D_read, I_read = read_index_lsh.search(xq, 10) |
| 408 | + |
| 409 | + np.testing.assert_array_equal(D, D_read) |
| 410 | + np.testing.assert_array_equal(I, I_read) |
| 411 | + |
| 412 | + finally: |
| 413 | + if os.path.exists(fname): |
| 414 | + os.unlink(fname) |
| 415 | + |
| 416 | + |
| 417 | +class Test_IO_IndexIVFSpectralHash(unittest.TestCase): |
| 418 | + """ |
| 419 | + test read and write IndexIVFSpectralHash. |
| 420 | + """ |
| 421 | + def test_io_ivf_spectral_hash(self): |
| 422 | + nlist = 1000 |
| 423 | + xt, xb, xq = get_dataset_2(d, nt, nb, nq) |
| 424 | + quantizer = faiss.IndexFlatL2(d) |
| 425 | + index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0) |
| 426 | + index.train(xt) |
| 427 | + index.add(xb) |
| 428 | + D, I = index.search(xq, 10) |
| 429 | + |
| 430 | + fd, fname = tempfile.mkstemp() |
| 431 | + os.close(fd) |
| 432 | + |
| 433 | + try: |
| 434 | + faiss.write_index(index, fname) |
| 435 | + |
| 436 | + reader = faiss.BufferedIOReader( |
| 437 | + faiss.FileIOReader(fname), 1234) |
| 438 | + read_index = faiss.read_index(reader) |
| 439 | + del reader |
| 440 | + |
| 441 | + self.assertEqual(index.d, read_index.d) |
| 442 | + self.assertEqual(index.nbit, read_index.nbit) |
| 443 | + self.assertEqual(index.period, read_index.period) |
| 444 | + self.assertEqual(index.threshold_type, read_index.threshold_type) |
| 445 | + |
| 446 | + D_read, I_read = read_index.search(xq, 10) |
| 447 | + np.testing.assert_array_equal(D, D_read) |
| 448 | + np.testing.assert_array_equal(I, I_read) |
| 449 | + |
| 450 | + finally: |
| 451 | + if os.path.exists(fname): |
| 452 | + os.unlink(fname) |
| 453 | + |
341 | 454 | class TestIVFPQRead(unittest.TestCase): |
342 | 455 | def test_reader(self): |
343 | 456 | d, n = 32, 1000 |
|
0 commit comments