diff --git a/tests/test_index.py b/tests/test_index.py index d266bdbcc0..2e0174d173 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -16,6 +16,7 @@ import warnings from common_faiss_tests import get_dataset, get_dataset_2 +from faiss.contrib.evaluation import check_ref_knn_with_draws class TestModuleInterface(unittest.TestCase): @@ -422,7 +423,7 @@ def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None): D, I, R = index.search_and_reconstruct(xq, k) np.testing.assert_almost_equal(D, D_ref, decimal=5) - self.assertTrue((I == I_ref).all()) + check_ref_knn_with_draws(D_ref, I_ref, D, I) self.assertEqual(R.shape[:2], I.shape) self.assertEqual(R.shape[2], d)