Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions faiss/impl/NSG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/impl/NSG.h>

#include <algorithm>
Expand All @@ -18,14 +16,16 @@

namespace faiss {

namespace nsg {

namespace {

using LockGuard = std::lock_guard<std::mutex>;

// Sentinel id for an unused slot; it must be negative so that it can never be mistaken for a valid (non-negative) node id
constexpr int EMPTY_ID = -1;

} // namespace
} // anonymous namespace

namespace nsg {

DistanceComputer* storage_distance_computer(const Index* storage) {
if (is_similarity_metric(storage->metric_type)) {
Expand All @@ -35,14 +35,8 @@ DistanceComputer* storage_distance_computer(const Index* storage) {
}
}

} // namespace nsg

using namespace nsg;

using LockGuard = std::lock_guard<std::mutex>;

struct Neighbor {
int id;
int32_t id;
float distance;
bool flag;

Expand All @@ -56,7 +50,7 @@ struct Neighbor {
};

struct Node {
int id;
int32_t id;
float distance;

Node() = default;
Expand All @@ -65,6 +59,11 @@ struct Node {
inline bool operator<(const Node& other) const {
return distance < other.distance;
}

// to keep the compiler happy
inline bool operator<(int other) const {
return id < other;
}
};

inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
Expand Down Expand Up @@ -106,6 +105,10 @@ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
return right;
}

} // namespace nsg

using namespace nsg;

NSG::NSG(int R) : R(R), rng(0x0903) {
L = R + 32;
C = R + 100;
Expand Down Expand Up @@ -253,9 +256,11 @@ void NSG::search_on_graph(
std::vector<int> init_ids(pool_size);

int num_ids = 0;
for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
int id = (int)graph.at(ep, i);
if (id < 0 || id >= ntotal) {
std::vector<index_t> neighbors(graph.K);
size_t nneigh = graph.get_neighbors(ep, neighbors.data());
for (int i = 0; i < init_ids.size() && i < nneigh; i++) {
int id = (int)neighbors[i];
if (id >= ntotal) {
continue;
}

Expand Down Expand Up @@ -296,9 +301,10 @@ void NSG::search_on_graph(
retset[k].flag = false;
int n = retset[k].id;

for (int m = 0; m < graph.K; m++) {
int id = (int)graph.at(n, m);
if (id < 0 || id > ntotal || vt.get(id)) {
size_t nneigh = graph.get_neighbors(n, neighbors.data());
for (int m = 0; m < nneigh; m++) {
int id = neighbors[m];
if (id > ntotal || vt.get(id)) {
continue;
}
vt.set(id);
Expand Down
24 changes: 18 additions & 6 deletions faiss/impl/NSG.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#pragma once

#include <memory>
Expand Down Expand Up @@ -40,11 +38,12 @@ namespace faiss {
*/

struct DistanceComputer; // from AuxIndexStructures
struct Neighbor;
struct Node;

namespace nsg {

struct Neighbor;
struct Node;

/***********************************************************
* Graph structure to store a graph.
*
Expand Down Expand Up @@ -75,7 +74,7 @@ struct Graph {
}

// release the allocated memory if needed
~Graph() {
virtual ~Graph() {
if (own_fields) {
delete[] data;
}
Expand All @@ -90,6 +89,17 @@ struct Graph {
// Mutable access to the j-th slot of row i in the flat table `data`
// (rows are stored back to back, K slots per row).
inline node_t& at(int i, int j) {
    // Widen before multiplying: `i * K` evaluated in int overflows (undefined
    // behaviour) as soon as the table holds more than INT_MAX slots.
    return data[static_cast<size_t>(i) * K + j];
}

// get all neighbors of node i (used during search only)
virtual size_t get_neighbors(int i, node_t* neighbors) const {
for (int j = 0; j < K; j++) {
if (data[i * K + j] < 0) {
return j;
}
neighbors[j] = data[i * K + j];
}
return K;
}
};

DistanceComputer* storage_distance_computer(const Index* storage);
Expand All @@ -99,6 +109,8 @@ DistanceComputer* storage_distance_computer(const Index* storage);
struct NSG {
/// internal storage of vectors (32 bits: this is expensive)
using storage_idx_t = int32_t;
using Node = nsg::Node;
using Neighbor = nsg::Neighbor;

int ntotal = 0; ///< nb of nodes

Expand All @@ -112,7 +124,7 @@ struct NSG {

int enterpoint; ///< enterpoint

std::shared_ptr<nsg::Graph<int>> final_graph; ///< NSG graph structure
std::shared_ptr<nsg::Graph<int32_t>> final_graph; ///< NSG graph structure

bool is_built = false; ///< NSG is built or not

Expand Down
85 changes: 85 additions & 0 deletions tests/test_NSG_compressed_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/IndexNSG.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
#include <gtest/gtest.h>

using namespace faiss;

using FinalNSGGraph = nsg::Graph<int32_t>;

/// Test-only graph that stores every adjacency list as a packed bitstring
/// (`bits` bits per entry) instead of reading the dense `data` table.
///
/// A list shorter than K is terminated by the value N. Valid node ids are in
/// [0, N), so the terminator can never collide with a real neighbor (a
/// terminator derived from K would collide with the node of the same id).
struct CompressedNSGGraph : FinalNSGGraph {
    int bits;      ///< number of bits used for each stored id
    size_t stride; ///< size in bytes of one compressed row
    std::vector<uint8_t> compressed_data; ///< N rows of `stride` bytes each

    /// Compress the adjacency lists of `graph` using `bits` bits per entry.
    /// Throws if `bits` cannot represent every id in [0, N] (ids plus the
    /// terminator N).
    CompressedNSGGraph(const FinalNSGGraph& graph, int bits)
            : FinalNSGGraph(graph.data, graph.N, graph.K), bits(bits) {
        // values are read back into an int32_t, and the shift must be defined
        FAISS_THROW_IF_NOT(bits > 0 && bits < 32);
        // the largest value ever written is the terminator N
        FAISS_THROW_IF_NOT((int64_t(1) << bits) >= int64_t(N) + 1);
        stride = (size_t(K) * bits + 7) / 8;
        compressed_data.resize(size_t(N) * stride);
        for (size_t i = 0; i < size_t(N); i++) {
            BitstringWriter writer(compressed_data.data() + i * stride, stride);
            for (size_t j = 0; j < size_t(K); j++) {
                int32_t v = graph.data[i * K + j];
                if (v < 0) {
                    // empty slot: mark the end of the list and stop
                    writer.write(N, bits);
                    break;
                }
                writer.write(v, bits);
            }
        }
        // drop the pointer to the uncompressed table, so neighbor lookups
        // can only be served from compressed_data via get_neighbors()
        data = nullptr;
    }

    /// Decode the neighbors of node i into `neighbors` (room for K entries)
    /// and return how many were written.
    size_t get_neighbors(int i, int32_t* neighbors) const override {
        BitstringReader reader(compressed_data.data() + i * stride, stride);
        for (int j = 0; j < K; j++) {
            int32_t v = reader.read(bits);
            if (v == N) {
                return j;
            }
            neighbors[j] = v;
        }
        return K;
    }
};

// Searching an NSG index must return exactly the same results after its
// final graph has been swapped for the bit-packed CompressedNSGGraph.
TEST(NSGCompressed, test_compressed) {
    using idx_t = faiss::idx_t;

    // sizes: queries, training vectors (none), database vectors,
    // dimension, results per query
    const size_t nq = 10;
    const size_t nt = 0;
    const size_t nb = 5000;
    const size_t d = 32;
    const size_t k = 10;

    // one buffer laid out as [training | database | queries]
    const size_t n_vectors = nq + nb + nt;
    std::vector<float> vectors(n_vectors * d);
    faiss::rand_smooth_vectors(n_vectors, d, vectors.data(), 1234);
    const float* train_begin = vectors.data();
    const float* database = train_begin + nt * d;
    const float* queries = database + nb * d;

    faiss::IndexNSGFlat index(d, 32);
    index.add(nb, database);

    // reference results with the plain graph
    std::vector<idx_t> Iref(nq * k);
    std::vector<float> Dref(nq * k);
    index.search(nq, queries, k, Dref.data(), Iref.data());

    // swap in the compressed graph (13 bits per stored id)
    index.nsg.final_graph.reset(
            new CompressedNSGGraph(*index.nsg.final_graph, 13));

    // same queries against the compressed graph
    std::vector<idx_t> I(nq * k);
    std::vector<float> D(nq * k);
    index.search(nq, queries, k, D.data(), I.data());

    // labels and distances must match the reference exactly
    EXPECT_EQ(Iref, I);
    EXPECT_EQ(Dref, D);
}