44# LICENSE file in the root directory of this source tree.
55
66import time
7+ import pickle
8+ import os
79from multiprocessing .pool import ThreadPool
810import threading
911
@@ -241,6 +243,28 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
241243 self .rh .add_result_subset (q_subset , D , I )
242244 self .t_accu [3 ] += time .time () - t0
243245
def sizes_in_checkpoint(self):
    """Return a signature of the current search setup.

    The tuple (query matrix shape, nprobe, nlist) identifies the search
    parameters; it is stored in a checkpoint and compared on restore so
    that state from a different search cannot be loaded by mistake.
    """
    query_shape = self.xq.shape
    return (query_shape, self.index.nprobe, self.index.nlist)
def write_checkpoint(self, fname, cur_list_no):
    """Atomically persist the search state to `fname`.

    Saves the setup signature (used for validation on restore), the
    inverted-list number to resume from, and the current result heap
    buffers. The data is pickled to a temporary file which is then
    renamed over the final name, so a crash mid-write cannot leave a
    truncated checkpoint behind.

    Parameters:
        fname: destination checkpoint file
        cur_list_no: inverted-list number to resume from
    """
    # write to temp file then move to final file
    tmpname = fname + ".tmp"
    # use a context manager so the handle is closed (and data flushed)
    # before the rename -- the original leaked the open file object
    with open(tmpname, "wb") as f:
        pickle.dump(
            {
                "sizes": self.sizes_in_checkpoint(),
                "cur_list_no": cur_list_no,
                "rh": (self.rh.D, self.rh.I),
            },
            f,
            -1,
        )
    os.rename(tmpname, fname)
def read_checkpoint(self, fname):
    """Restore search state from a checkpoint written by write_checkpoint.

    Validates that the checkpoint was produced by an identical search
    setup, restores the result heap buffers in place, and returns the
    inverted-list number to resume from.

    Parameters:
        fname: checkpoint file to load

    Returns:
        the `cur_list_no` stored in the checkpoint

    Raises:
        AssertionError: if the checkpoint's setup signature does not
            match the current search setup.
    """
    # use a context manager so the file handle is closed -- the
    # original leaked the open file object
    with open(fname, "rb") as f:
        ckp = pickle.load(f)
    assert ckp["sizes"] == self.sizes_in_checkpoint()
    # in-place slice assignment keeps the existing heap buffers
    self.rh.D[:] = ckp["rh"][0]
    self.rh.I[:] = ckp["rh"][1]
    return ckp["cur_list_no"]
244268
245269class BlockComputer :
246270 """ computation within one bucket """
@@ -308,7 +332,13 @@ def big_batch_search(
308332 use_float16 = False ,
309333 prefetch_threads = 8 ,
310334 computation_threads = 0 ,
311- q_assign = None ):
335+ q_assign = None ,
336+ checkpoint = None ,
337+ checkpoint_freq = 64 ,
338+ start_list = 0 ,
339+ end_list = None ,
340+ crash_at = - 1
341+ ):
312342 """
313343 Search queries xq in the IVF index, with a search function that collects
314344 batches of query vectors per inverted list. This can be faster than the
@@ -336,7 +366,13 @@ def big_batch_search(
336366
337367 use_float16: convert all matrices to float16 (faster for GPU gemm)
338368
339- q_assign: override coarse assignment
369+ q_assign: override coarse assignment, should be a matrix of size nq * nprobe
370+
371+ checkpointing (only for threaded > 1):
372+ checkpoint: file where the checkpoints are stored
372+ checkpoint_freq: when to perform checkpointing. Should be a multiple of threaded
374+
375+ start_list, end_list: process only a subset of invlists
340376 """
341377 nprobe = index .nprobe
342378
@@ -377,10 +413,22 @@ def big_batch_search(
377413 bbs .q_assign = q_assign
378414 bbs .reorder_assign ()
379415
416+ if end_list is None :
417+ end_list = index .nlist
418+
419+ if checkpoint is not None :
420+ assert (start_list , end_list ) == (0 , index .nlist )
421+ if os .path .exists (checkpoint ):
422+ print ("recovering checkpoint" , checkpoint )
423+ start_list = bbs .read_checkpoint (checkpoint )
424+ print (" start at list" , start_list )
425+ else :
426+ print ("no checkpoint: starting from scratch" )
427+
380428 if threaded == 0 :
381429 # simple sequential version
382430
383- for l in range (index . nlist ):
431+ for l in range (start_list , end_list ):
384432 bbs .report (l )
385433 q_subset , xq_l , list_ids , xb_l = bbs .prepare_bucket (l )
386434 t0i = time .time ()
@@ -400,11 +448,11 @@ def add_results_and_prefetch(to_add, l):
400448 if l < index .nlist :
401449 return bbs .prepare_bucket (l )
402450
403- prefetched_bucket = bbs .prepare_bucket (0 )
451+ prefetched_bucket = bbs .prepare_bucket (start_list )
404452 to_add = None
405453 pool = ThreadPool (1 )
406454
407- for l in range (index . nlist ):
455+ for l in range (start_list , end_list ):
408456 bbs .report (l )
409457 prefetched_bucket_a = pool .apply_async (
410458 add_results_and_prefetch , (to_add , l + 1 ))
@@ -422,6 +470,7 @@ def add_results_and_prefetch(to_add, l):
422470 else :
423471 # run by batches with parallel prefetch and parallel comp
424472 list_step = threaded
473+ assert start_list % list_step == 0
425474
426475 if prefetch_threads == 0 :
427476 prefetch_map = map
@@ -462,13 +511,13 @@ def do_comp(bucket):
462511 D , I = comp .block_search (xq_l , xb_l , list_ids , k , thread_id = tid )
463512 return q_subset , D , list_ids , I
464513
465- prefetched_buckets = add_results_and_prefetch_batch ([], 0 )
514+ prefetched_buckets = add_results_and_prefetch_batch ([], start_list )
466515 to_add = []
467516 pool = ThreadPool (1 )
468517 prefetched_buckets_a = None
469518
470519 # loop over inverted lists
471- for l in range (0 , index . nlist , list_step ):
520+ for l in range (start_list , end_list , list_step ):
472521 bbs .report (l )
473522 buckets = prefetched_buckets
474523 prefetched_buckets_a = pool .apply_async (
@@ -486,10 +535,19 @@ def do_comp(bucket):
486535
487536 bbs .stop_t_accu (2 )
488537
538+ # to test checkpointing
539+ if l == crash_at :
540+ 1 / 0
541+
489542 bbs .start_t_accu ()
490543 prefetched_buckets = prefetched_buckets_a .get ()
491544 bbs .stop_t_accu (4 )
492545
546+ if checkpoint is not None :
547+ if (l // list_step ) % checkpoint_freq == 0 :
548+ print ("writing checkpoint %s" % l )
549+ bbs .write_checkpoint (checkpoint , l )
550+
493551 # flush add
494552 for ta in to_add :
495553 bbs .add_results_to_heap (* ta )
0 commit comments