NSG: route graph traversal through a virtual Graph::get_neighbors() and add a
bit-packed graph test.

Summary of the change set:
  * faiss/impl/NSG.cpp -- Neighbor, Node and insert_into_pool move inside
    namespace nsg; search_on_graph reads adjacency lists through
    graph.get_neighbors() instead of graph.at().
  * faiss/impl/NSG.h -- Graph gets a virtual destructor and the virtual
    get_neighbors() accessor; NSG aliases nsg::Node and nsg::Neighbor.
  * tests/test_NSG_compressed_graph.cpp -- new test that swaps in a bit-packed
    Graph subclass and expects identical search results.

Amendments relative to the patch as first posted:
  * search_on_graph: the expansion loop skips id >= ntotal (was id > ntotal),
    matching the bound used in the entry-point loop of the same function, so
    id == ntotal never reaches vt.get().
  * CompressedNSGGraph: the end-of-list marker is N instead of K + 1. K + 1 is
    a valid node id whenever N > K + 1, so a real neighbor could be read back
    as a terminator. The width check now requires room for N as well.

NOTE(review): this text was recovered from a copy in which line breaks,
indentation and everything between angle brackets had been lost. Template
arguments and #include targets below are restored from context and need to be
confirmed against the tree before applying. The blob ids on the index lines
are those of the original patch and do not describe the amended post-image.

diff --git a/faiss/impl/NSG.cpp b/faiss/impl/NSG.cpp
index c974943343..0aa8197ada 100644
--- a/faiss/impl/NSG.cpp
+++ b/faiss/impl/NSG.cpp
@@ -5,8 +5,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 #include <faiss/impl/NSG.h>
 
 #include <algorithm>
@@ -18,14 +16,16 @@
 
 namespace faiss {
 
-namespace nsg {
-
 namespace {
 
+using LockGuard = std::lock_guard<std::mutex>;
+
 // It needs to be smaller than 0
 constexpr int EMPTY_ID = -1;
 
-} // namespace
+} // anonymous namespace
+
+namespace nsg {
 
 DistanceComputer* storage_distance_computer(const Index* storage) {
     if (is_similarity_metric(storage->metric_type)) {
@@ -35,14 +35,8 @@ DistanceComputer* storage_distance_computer(const Index* storage) {
     }
 }
 
-} // namespace nsg
-
-using namespace nsg;
-
-using LockGuard = std::lock_guard<std::mutex>;
-
 struct Neighbor {
-    int id;
+    int32_t id;
     float distance;
     bool flag;
 
@@ -56,7 +50,7 @@ struct Neighbor {
 };
 
 struct Node {
-    int id;
+    int32_t id;
     float distance;
 
     Node() = default;
@@ -65,6 +59,11 @@ struct Node {
     inline bool operator<(const Node& other) const {
         return distance < other.distance;
     }
+
+    // to keep the compiler happy
+    inline bool operator<(int other) const {
+        return id < other;
+    }
 };
 
 inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
@@ -106,6 +105,10 @@ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
     return right;
 }
 
+} // namespace nsg
+
+using namespace nsg;
+
 NSG::NSG(int R) : R(R), rng(0x0903) {
     L = R + 32;
     C = R + 100;
@@ -253,9 +256,11 @@ void NSG::search_on_graph(
     std::vector<int> init_ids(pool_size);
 
     int num_ids = 0;
-    for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
-        int id = (int)graph.at(ep, i);
-        if (id < 0 || id >= ntotal) {
+    std::vector<index_t> neighbors(graph.K);
+    size_t nneigh = graph.get_neighbors(ep, neighbors.data());
+    for (int i = 0; i < init_ids.size() && i < nneigh; i++) {
+        int id = (int)neighbors[i];
+        if (id >= ntotal) {
             continue;
         }
 
@@ -296,9 +301,10 @@ void NSG::search_on_graph(
             retset[k].flag = false;
             int n = retset[k].id;
 
-            for (int m = 0; m < graph.K; m++) {
-                int id = (int)graph.at(n, m);
-                if (id < 0 || id > ntotal || vt.get(id)) {
+            size_t nneigh = graph.get_neighbors(n, neighbors.data());
+            for (int m = 0; m < nneigh; m++) {
+                int id = neighbors[m];
+                if (id >= ntotal || vt.get(id)) {
                     continue;
                 }
                 vt.set(id);
diff --git a/faiss/impl/NSG.h b/faiss/impl/NSG.h
index 641a42f8cf..2f59bc2f8b 100644
--- a/faiss/impl/NSG.h
+++ b/faiss/impl/NSG.h
@@ -5,8 +5,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 #pragma once
 
 #include <memory>
@@ -40,11 +38,12 @@ namespace faiss {
  */
 
 struct DistanceComputer; // from AuxIndexStructures
-struct Neighbor;
-struct Node;
 
 namespace nsg {
 
+struct Neighbor;
+struct Node;
+
 /***********************************************************
  * Graph structure to store a graph.
  *
@@ -75,7 +74,7 @@ struct Graph {
     }
 
     // release the allocated memory if needed
-    ~Graph() {
+    virtual ~Graph() {
         if (own_fields) {
             delete[] data;
         }
@@ -90,6 +89,17 @@ struct Graph {
     inline node_t& at(int i, int j) {
         return data[i * K + j];
     }
+
+    // get all neighbors of node i (used during search only)
+    virtual size_t get_neighbors(int i, node_t* neighbors) const {
+        for (int j = 0; j < K; j++) {
+            if (data[i * K + j] < 0) {
+                return j;
+            }
+            neighbors[j] = data[i * K + j];
+        }
+        return K;
+    }
 };
 
 DistanceComputer* storage_distance_computer(const Index* storage);
@@ -99,6 +109,8 @@ DistanceComputer* storage_distance_computer(const Index* storage);
 struct NSG {
     /// internal storage of vectors (32 bits: this is expensive)
     using storage_idx_t = int32_t;
+    using Node = nsg::Node;
+    using Neighbor = nsg::Neighbor;
 
     int ntotal = 0; ///< nb of nodes
 
@@ -112,7 +124,7 @@ struct NSG {
 
     int enterpoint; ///< enterpoint
 
-    std::shared_ptr<nsg::Graph<int>> final_graph; ///< NSG graph structure
+    std::shared_ptr<nsg::Graph<int32_t>> final_graph; ///< NSG graph structure
 
     bool is_built = false; ///< NSG is built or not
 
diff --git a/tests/test_NSG_compressed_graph.cpp b/tests/test_NSG_compressed_graph.cpp
new file mode 100644
index 0000000000..ecfc856be4
--- /dev/null
+++ b/tests/test_NSG_compressed_graph.cpp
@@ -0,0 +1,88 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/IndexNSG.h>
+#include <faiss/utils/hamming.h>
+#include <faiss/utils/random.h>
+#include <gtest/gtest.h>
+
+using namespace faiss;
+
+using FinalNSGGraph = nsg::Graph<int32_t>;
+
+// Bit-packed copy of an NSG graph: each neighbor id takes `bits` bits, and
+// the out-of-range id N terminates a neighbor list shorter than K.
+struct CompressedNSGGraph : FinalNSGGraph {
+    int bits;
+    size_t stride;
+    std::vector<uint8_t> compressed_data;
+
+    CompressedNSGGraph(const FinalNSGGraph& graph, int bits)
+            : FinalNSGGraph(graph.data, graph.N, graph.K), bits(bits) {
+        // every id in [0, N) plus the end marker N must fit in `bits` bits
+        FAISS_THROW_IF_NOT((1 << bits) >= N + 1);
+        stride = (K * bits + 7) / 8;
+        compressed_data.resize(N * stride);
+        for (size_t i = 0; i < N; i++) {
+            BitstringWriter writer(compressed_data.data() + i * stride, stride);
+            for (size_t j = 0; j < K; j++) {
+                int32_t v = graph.data[i * K + j];
+                if (v == -1) {
+                    writer.write(N, bits);
+                    break;
+                } else {
+                    writer.write(v, bits);
+                }
+            }
+        }
+        data = nullptr;
+    }
+
+    size_t get_neighbors(int i, int32_t* neighbors) const override {
+        BitstringReader reader(compressed_data.data() + i * stride, stride);
+        for (int j = 0; j < K; j++) {
+            int32_t v = reader.read(bits);
+            if (v == N) {
+                return j;
+            }
+            neighbors[j] = v;
+        }
+        return K;
+    }
+};
+
+TEST(NSGCompressed, test_compressed) {
+    size_t nq = 10, nt = 0, nb = 5000, d = 32, k = 10;
+
+    using idx_t = faiss::idx_t;
+
+    std::vector<float> buf((nq + nb + nt) * d);
+    faiss::rand_smooth_vectors(nq + nb + nt, d, buf.data(), 1234);
+    const float* xt = buf.data();
+    const float* xb = xt + nt * d;
+    const float* xq = xb + nb * d;
+
+    faiss::IndexNSGFlat index(d, 32);
+
+    index.add(nb, xb);
+
+    std::vector<idx_t> Iref(nq * k);
+    std::vector<float> Dref(nq * k);
+    index.search(nq, xq, k, Dref.data(), Iref.data());
+
+    // replace the shared ptr
+    index.nsg.final_graph.reset(
+            new CompressedNSGGraph(*index.nsg.final_graph, 13));
+
+    std::vector<idx_t> I(nq * k);
+    std::vector<float> D(nq * k);
+    index.search(nq, xq, k, D.data(), I.data());
+
+    // make sure we find back the original results
+    EXPECT_EQ(Iref, I);
+    EXPECT_EQ(Dref, D);
+}