Skip to content

Commit 2d7dd5b

Browse files
mdouze authored and facebook-github-bot committed
support checkpointing in big batch search
Summary: Big batch search can be running for hours so it's useful to have a checkpointing mechanism in case it's run on a best-effort cluster queue. Reviewed By: algoriddle Differential Revision: D44059758 fbshipit-source-id: 5cb5e80800c6d2bf76d9f6cb40736009cd5d4b8e
1 parent 371d9c2 commit 2d7dd5b

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

contrib/ivf_tools.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import time
7+
import pickle
8+
import os
79
from multiprocessing.pool import ThreadPool
810
import threading
911

@@ -241,6 +243,28 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
241243
self.rh.add_result_subset(q_subset, D, I)
242244
self.t_accu[3] += time.time() - t0
243245

246+
def sizes_in_checkpoint(self):
    """Return the size tuple that identifies this search setup.

    Used to verify that a checkpoint on disk matches the current
    query set and index configuration before restoring it.
    """
    query_shape = self.xq.shape
    nprobe = self.index.nprobe
    nlist = self.index.nlist
    return (query_shape, nprobe, nlist)
248+
249+
def write_checkpoint(self, fname, cur_list_no):
    """Atomically persist the current search state to `fname`.

    The state is pickled to a temporary file first and then moved
    into place, so a crash mid-write never leaves a truncated
    checkpoint behind.

    Parameters:
        fname: destination path of the checkpoint file
        cur_list_no: index of the next inverted list to process
    """
    # write to temp file then move to final file
    tmpname = fname + ".tmp"
    # use a context manager so the handle is closed before the rename
    # (required on Windows; also avoids relying on CPython refcounting
    # to close the leaked inline open() of the original code)
    with open(tmpname, "wb") as f:
        pickle.dump(
            {
                "sizes": self.sizes_in_checkpoint(),
                "cur_list_no": cur_list_no,
                "rh": (self.rh.D, self.rh.I),
            }, f, -1
        )
    # os.replace overwrites an existing destination atomically on all
    # platforms, unlike os.rename which raises on Windows
    os.replace(tmpname, fname)
261+
def read_checkpoint(self, fname):
    """Restore search state from a checkpoint written by write_checkpoint.

    Verifies the checkpoint matches the current query set and index
    configuration, restores the result-heap contents in place, and
    returns the inverted-list index at which to resume.

    Returns:
        the `cur_list_no` stored in the checkpoint.

    Raises:
        AssertionError: if the checkpoint was written for a different
            query set / index configuration.
    """
    # close the file handle deterministically instead of leaking the
    # inline open() as the original code did
    with open(fname, "rb") as f:
        ckp = pickle.load(f)
    assert ckp["sizes"] == self.sizes_in_checkpoint()
    self.rh.D[:] = ckp["rh"][0]
    self.rh.I[:] = ckp["rh"][1]
    return ckp["cur_list_no"]
267+
244268

245269
class BlockComputer:
246270
""" computation within one bucket """
@@ -308,7 +332,13 @@ def big_batch_search(
308332
use_float16=False,
309333
prefetch_threads=8,
310334
computation_threads=0,
311-
q_assign=None):
335+
q_assign=None,
336+
checkpoint=None,
337+
checkpoint_freq=64,
338+
start_list=0,
339+
end_list=None,
340+
crash_at=-1
341+
):
312342
"""
313343
Search queries xq in the IVF index, with a search function that collects
314344
batches of query vectors per inverted list. This can be faster than the
@@ -336,7 +366,13 @@ def big_batch_search(
336366
337367
use_float16: convert all matrices to float16 (faster for GPU gemm)
338368
339-
q_assign: override coarse assignment
369+
q_assign: override coarse assignment, should be a matrix of size nq * nprobe
370+
371+
checkpointing (only for threaded > 1):
372+
checkpoint: file where the checkpoints are stored
373+
checkpoint_freq: when to perform checkpointing. Should be a multiple of threaded
374+
375+
start_list, end_list: process only a subset of invlists
340376
"""
341377
nprobe = index.nprobe
342378

@@ -377,10 +413,22 @@ def big_batch_search(
377413
bbs.q_assign = q_assign
378414
bbs.reorder_assign()
379415

416+
if end_list is None:
417+
end_list = index.nlist
418+
419+
if checkpoint is not None:
420+
assert (start_list, end_list) == (0, index.nlist)
421+
if os.path.exists(checkpoint):
422+
print("recovering checkpoint", checkpoint)
423+
start_list = bbs.read_checkpoint(checkpoint)
424+
print(" start at list", start_list)
425+
else:
426+
print("no checkpoint: starting from scratch")
427+
380428
if threaded == 0:
381429
# simple sequential version
382430

383-
for l in range(index.nlist):
431+
for l in range(start_list, end_list):
384432
bbs.report(l)
385433
q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
386434
t0i = time.time()
@@ -400,11 +448,11 @@ def add_results_and_prefetch(to_add, l):
400448
if l < index.nlist:
401449
return bbs.prepare_bucket(l)
402450

403-
prefetched_bucket = bbs.prepare_bucket(0)
451+
prefetched_bucket = bbs.prepare_bucket(start_list)
404452
to_add = None
405453
pool = ThreadPool(1)
406454

407-
for l in range(index.nlist):
455+
for l in range(start_list, end_list):
408456
bbs.report(l)
409457
prefetched_bucket_a = pool.apply_async(
410458
add_results_and_prefetch, (to_add, l + 1))
@@ -422,6 +470,7 @@ def add_results_and_prefetch(to_add, l):
422470
else:
423471
# run by batches with parallel prefetch and parallel comp
424472
list_step = threaded
473+
assert start_list % list_step == 0
425474

426475
if prefetch_threads == 0:
427476
prefetch_map = map
@@ -462,13 +511,13 @@ def do_comp(bucket):
462511
D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
463512
return q_subset, D, list_ids, I
464513

465-
prefetched_buckets = add_results_and_prefetch_batch([], 0)
514+
prefetched_buckets = add_results_and_prefetch_batch([], start_list)
466515
to_add = []
467516
pool = ThreadPool(1)
468517
prefetched_buckets_a = None
469518

470519
# loop over inverted lists
471-
for l in range(0, index.nlist, list_step):
520+
for l in range(start_list, end_list, list_step):
472521
bbs.report(l)
473522
buckets = prefetched_buckets
474523
prefetched_buckets_a = pool.apply_async(
@@ -486,10 +535,19 @@ def do_comp(bucket):
486535

487536
bbs.stop_t_accu(2)
488537

538+
# to test checkpointing
539+
if l == crash_at:
540+
1 / 0
541+
489542
bbs.start_t_accu()
490543
prefetched_buckets = prefetched_buckets_a.get()
491544
bbs.stop_t_accu(4)
492545

546+
if checkpoint is not None:
547+
if (l // list_step) % checkpoint_freq == 0:
548+
print("writing checkpoint %s" % l)
549+
bbs.write_checkpoint(checkpoint, l)
550+
493551
# flush add
494552
for ta in to_add:
495553
bbs.add_results_to_heap(*ta)

tests/test_contrib.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88
import numpy as np
99
import platform
10+
import os
1011

1112
from faiss.contrib import datasets
1213
from faiss.contrib import inspect_tools
@@ -514,3 +515,40 @@ def test_PQ(self):
514515

515516
def test_SQ(self):
516517
self.do_test("IVF64,SQ8")
518+
519+
def test_checkpoint(self):
    """Check that big_batch_search can crash mid-run and resume from its
    checkpoint file, producing the same results as a plain search."""
    ds = datasets.SyntheticDataset(32, 2000, 400, 500)
    k = 10
    index = faiss.index_factory(ds.d, "IVF64,SQ8")
    index.train(ds.get_train())
    index.add(ds.get_database())
    index.nprobe = 5
    Dref, Iref = index.search(ds.get_queries(), k)

    # use the platform temp dir instead of a hard-coded /tmp
    # (portable to Windows / systems without /tmp)
    import tempfile
    checkpoint = os.path.join(
        tempfile.gettempdir(),
        "test_big_batch_checkpoint.%d" % np.random.randint(int(1e16))
    )
    try:
        # First big batch search: forced to crash mid-way (crash_at)
        # to exercise checkpoint recovery
        try:
            Dnew, Inew = ivf_tools.big_batch_search(
                index, ds.get_queries(),
                k, method="knn_function",
                threaded=4,
                checkpoint=checkpoint, checkpoint_freq=4,
                crash_at=20
            )
        except ZeroDivisionError:
            pass
        else:
            self.assertFalse("should have crashed")
        # Second big batch search: resumes from the checkpoint written
        # before the crash and runs to completion
        Dnew, Inew = ivf_tools.big_batch_search(
            index, ds.get_queries(),
            k, method="knn_function",
            threaded=4,
            checkpoint=checkpoint, checkpoint_freq=4
        )
        # resumed results must match the reference search almost exactly
        self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
        np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
    finally:
        if os.path.exists(checkpoint):
            os.unlink(checkpoint)

0 commit comments

Comments
 (0)