Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,17 @@ void IndexHNSW::reconstruct(idx_t key, float* recons) const {
storage->reconstruct(key, recons);
}

/**************************************************************
* The functions in this section were used during the development of HNSW support.
* They may be useful in the future but are dormant for now, and thus are not
* unit tested at the moment.
* shrink_level_0_neighbors
* search_level_0
* init_level_0_from_knngraph
* init_level_0_from_entry_points
* reorder_links
* link_singletons
**************************************************************/
void IndexHNSW::shrink_level_0_neighbors(int new_size) {
#pragma omp parallel
{
Expand Down
122 changes: 46 additions & 76 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ void add_link(
}
}

} // namespace

/// search neighbors on a single level, starting from an entry point
void search_neighbors_to_add(
HNSW& hnsw,
Expand All @@ -360,9 +362,6 @@ void search_neighbors_to_add(
float d_entry_point,
int level,
VisitedTable& vt) {
// selects a version
const bool reference_version = false;

// top is nearest candidate
std::priority_queue<NodeDistFarther> candidates;

Expand All @@ -385,93 +384,64 @@ void search_neighbors_to_add(
size_t begin, end;
hnsw.neighbor_range(currNode, level, &begin, &end);

// select a version, based on a flag
if (reference_version) {
// a reference version
for (size_t i = begin; i < end; i++) {
storage_idx_t nodeId = hnsw.neighbors[i];
if (nodeId < 0)
break;
if (vt.get(nodeId))
continue;
vt.set(nodeId);

float dis = qdis(nodeId);
NodeDistFarther evE1(dis, nodeId);

if (results.size() < hnsw.efConstruction ||
results.top().d > dis) {
results.emplace(dis, nodeId);
candidates.emplace(dis, nodeId);
if (results.size() > hnsw.efConstruction) {
results.pop();
}
// process 4 neighbors at a time
// Compare this to the reference version in test_hnsw.cpp
auto update_with_candidate = [&](const storage_idx_t idx,
const float dis) {
if (results.size() < hnsw.efConstruction || results.top().d > dis) {
results.emplace(dis, idx);
candidates.emplace(dis, idx);
if (results.size() > hnsw.efConstruction) {
results.pop();
}
}
} else {
// a faster version

// the following version processes 4 neighbors at a time
auto update_with_candidate = [&](const storage_idx_t idx,
const float dis) {
if (results.size() < hnsw.efConstruction ||
results.top().d > dis) {
results.emplace(dis, idx);
candidates.emplace(dis, idx);
if (results.size() > hnsw.efConstruction) {
results.pop();
}
}
};
};

int n_buffered = 0;
storage_idx_t buffered_ids[4];

int n_buffered = 0;
storage_idx_t buffered_ids[4];
for (size_t j = begin; j < end; j++) {
storage_idx_t nodeId = hnsw.neighbors[j];
if (nodeId < 0)
break;
if (vt.get(nodeId)) {
continue;
}
vt.set(nodeId);

for (size_t j = begin; j < end; j++) {
storage_idx_t nodeId = hnsw.neighbors[j];
if (nodeId < 0)
break;
if (vt.get(nodeId)) {
continue;
}
vt.set(nodeId);

buffered_ids[n_buffered] = nodeId;
n_buffered += 1;

if (n_buffered == 4) {
float dis[4];
qdis.distances_batch_4(
buffered_ids[0],
buffered_ids[1],
buffered_ids[2],
buffered_ids[3],
dis[0],
dis[1],
dis[2],
dis[3]);

for (size_t id4 = 0; id4 < 4; id4++) {
update_with_candidate(buffered_ids[id4], dis[id4]);
}
buffered_ids[n_buffered] = nodeId;
n_buffered += 1;

n_buffered = 0;
if (n_buffered == 4) {
float dis[4];
qdis.distances_batch_4(
buffered_ids[0],
buffered_ids[1],
buffered_ids[2],
buffered_ids[3],
dis[0],
dis[1],
dis[2],
dis[3]);

for (size_t id4 = 0; id4 < 4; id4++) {
update_with_candidate(buffered_ids[id4], dis[id4]);
}
}

// process leftovers
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
float dis = qdis(buffered_ids[icnt]);
update_with_candidate(buffered_ids[icnt], dis);
n_buffered = 0;
}
}

// process leftovers
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
float dis = qdis(buffered_ids[icnt]);
update_with_candidate(buffered_ids[icnt], dis);
}
}

vt.advance();
}

} // namespace

/// Finds neighbors and builds links with them, starting from an entry
/// point. The own neighbor list is assumed to be locked.
void HNSW::add_links_starting_from(
Expand Down
9 changes: 9 additions & 0 deletions faiss/impl/HNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,13 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
VisitedTable* vt,
HNSWStats& stats);

/// Search the neighbors to connect a new point to, on a single level of the
/// HNSW graph, starting from the given entry point; candidate neighbors are
/// accumulated into `results`.
void search_neighbors_to_add(
HNSW& hnsw,
DistanceComputer& qdis,
std::priority_queue<HNSW::NodeDistCloser>& results,
int entry_point,
float d_entry_point,
int level,
VisitedTable& vt);

} // namespace faiss
53 changes: 50 additions & 3 deletions tests/test_graph_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def test_hnsw_unbounded_queue(self):

self.io_and_retest(index, Dhnsw, Ihnsw)

def test_hnsw_no_init_level0(self):
    """Searching must still work (with degraded recall) when the HNSW
    graph is built without initializing vectors into level 0."""
    dim = self.xq.shape[1]

    hnsw_index = faiss.IndexHNSWFlat(dim, 16)
    hnsw_index.init_level0 = False
    hnsw_index.add(self.xb)
    D, I = hnsw_index.search(self.xq, 1)

    # Recall is expected to be lower than the default configuration
    # because vectors are never linked into level 0 of the graph.
    n_agree = (self.Iref == I).sum()
    self.assertGreaterEqual(n_agree, 25)

    self.io_and_retest(hnsw_index, D, I)

def io_and_retest(self, index, Dhnsw, Ihnsw):
index2 = faiss.deserialize_index(faiss.serialize_index(index))
Dhnsw2, Ihnsw2 = index2.search(self.xq, 1)
Expand Down Expand Up @@ -101,6 +115,24 @@ def test_hnsw_2level(self):

self.io_and_retest(index, Dhnsw, Ihnsw)

def test_hnsw_2level_mixed_search(self):
    """Two-level HNSW backed by an IVFPQ storage index: the mixed
    search path is expected to lose some accuracy vs. brute force."""
    dim = self.xq.shape[1]

    coarse = faiss.IndexFlatL2(dim)

    ivfpq_storage = faiss.IndexIVFPQ(coarse, dim, 32, 8, 8)
    ivfpq_storage.make_direct_map()
    mixed = faiss.IndexHNSW2Level(coarse, 32, 8, 8)
    mixed.storage = ivfpq_storage
    mixed.train(self.xb)
    mixed.add(self.xb)
    D, I = mixed.search(self.xq, 1)

    # The mixed search trades accuracy for speed, so only a modest
    # agreement with the brute-force reference is required.
    self.assertGreaterEqual((self.Iref == I).sum(), 200)

    self.io_and_retest(mixed, D, I)

def test_add_0_vecs(self):
index = faiss.IndexHNSWFlat(10, 16)
zero_vecs = np.zeros((0, 10), dtype='float32')
Expand Down Expand Up @@ -175,16 +207,31 @@ def test_abs_inner_product(self):
xb = self.xb - self.xb.mean(axis=0) # need to be centered to give interesting directions
xq = self.xq - self.xq.mean(axis=0)
Dref, Iref = faiss.knn(xq, xb, 10, faiss.METRIC_ABS_INNER_PRODUCT)

index = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_ABS_INNER_PRODUCT)
index.add(xb)
Dnew, Inew = index.search(xq, 10)

inter = faiss.eval_intersection(Iref, Inew)
# 4769 vs. 500*10
self.assertGreater(inter, Iref.size * 0.9)



def test_hnsw_reset(self):
    """reset() on an IndexHNSW must clear both the wrapper and the
    underlying storage index."""
    dim = self.xb.shape[1]

    storage = faiss.IndexFlat(dim)
    storage.add(self.xb)
    self.assertEqual(storage.ntotal, self.xb.shape[0])

    wrapper = faiss.IndexHNSW(storage)
    wrapper.add(self.xb)
    # The vectors end up in storage twice (once added directly, once
    # through the wrapper). This is deliberate: the point is only to
    # verify that reset() clears everything.
    self.assertEqual(wrapper.ntotal, self.xb.shape[0] * 2)

    wrapper.reset()

    self.assertEqual(storage.ntotal, 0)
    self.assertEqual(wrapper.ntotal, 0)

class Issue3684(unittest.TestCase):

def test_issue3684(self):
Expand Down
Loading