77from copy import copy
88from dataclasses import dataclass
99from operator import itemgetter
10- from statistics import median , mean
10+ from statistics import mean , median
1111from typing import Any , Dict , List , Optional
1212
13- from .utils import dict_merge
14- from .index import Index , IndexFromCodec , IndexFromFactory
15- from .descriptors import DatasetDescriptor , IndexDescriptor
16-
1713import faiss # @manual=//faiss/python:pyfaiss_gpu
1814
1915import numpy as np
2016
2117from scipy .optimize import curve_fit
2218
19+ from .descriptors import DatasetDescriptor , IndexDescriptor
20+ from .index import Index , IndexFromCodec , IndexFromFactory
21+
22+ from .utils import dict_merge
23+
2324logger = logging .getLogger (__name__ )
2425
2526
@@ -274,8 +275,8 @@ def range_search(
274275 search_parameters : Optional [Dict [str , int ]],
275276 radius : Optional [float ] = None ,
276277 gt_radius : Optional [float ] = None ,
277- range_search_metric_function = None ,
278- gt_rsm = None ,
278+ range_search_metric_function = None ,
279+ gt_rsm = None ,
279280 ):
280281 logger .info ("range_search: begin" )
281282 if radius is None :
@@ -328,7 +329,13 @@ def knn_ground_truth(self):
328329 logger .info ("knn_ground_truth: begin" )
329330 flat_desc = self .get_index_desc ("Flat" )
330331 self .build_index_wrapper (flat_desc )
331- self .gt_knn_D , self .gt_knn_I , _ , _ , requires = flat_desc .index .knn_search (
332+ (
333+ self .gt_knn_D ,
334+ self .gt_knn_I ,
335+ _ ,
336+ _ ,
337+ requires ,
338+ ) = flat_desc .index .knn_search (
332339 dry_run = False ,
333340 search_parameters = None ,
334341 query_vectors = self .query_vectors ,
@@ -338,13 +345,13 @@ def knn_ground_truth(self):
338345 logger .info ("knn_ground_truth: end" )
339346
340347 def search_benchmark (
341- self ,
348+ self ,
342349 name ,
343350 search_func ,
344351 key_func ,
345352 cost_metrics ,
346353 perf_metrics ,
347- results : Dict [str , Any ],
354+ results : Dict [str , Any ],
348355 index : Index ,
349356 ):
350357 index_name = index .get_index_name ()
@@ -376,11 +383,18 @@ def experiment(parameters, cost_metric, perf_metric):
376383 logger .info (f"{ name } _benchmark: end" )
377384 return results , requires
378385
379- def knn_search_benchmark (self , dry_run , results : Dict [str , Any ], index : Index ):
386+ def knn_search_benchmark (
387+ self , dry_run , results : Dict [str , Any ], index : Index
388+ ):
380389 return self .search_benchmark (
381390 name = "knn_search" ,
382391 search_func = lambda parameters : index .knn_search (
383- dry_run , parameters , self .query_vectors , self .k , self .gt_knn_I , self .gt_knn_D ,
392+ dry_run ,
393+ parameters ,
394+ self .query_vectors ,
395+ self .k ,
396+ self .gt_knn_I ,
397+ self .gt_knn_D ,
384398 )[3 :],
385399 key_func = lambda parameters : index .get_knn_search_name (
386400 search_parameters = parameters ,
@@ -394,11 +408,17 @@ def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
394408 index = index ,
395409 )
396410
397- def reconstruct_benchmark (self , dry_run , results : Dict [str , Any ], index : Index ):
411+ def reconstruct_benchmark (
412+ self , dry_run , results : Dict [str , Any ], index : Index
413+ ):
398414 return self .search_benchmark (
399415 name = "reconstruct" ,
400416 search_func = lambda parameters : index .reconstruct (
401- dry_run , parameters , self .query_vectors , self .k , self .gt_knn_I ,
417+ dry_run ,
418+ parameters ,
419+ self .query_vectors ,
420+ self .k ,
421+ self .gt_knn_I ,
402422 ),
403423 key_func = lambda parameters : index .get_knn_search_name (
404424 search_parameters = parameters ,
@@ -426,31 +446,33 @@ def range_search_benchmark(
426446 return self .search_benchmark (
427447 name = "range_search" ,
428448 search_func = lambda parameters : self .range_search (
429- dry_run = dry_run ,
430- index = index ,
431- search_parameters = parameters ,
449+ dry_run = dry_run ,
450+ index = index ,
451+ search_parameters = parameters ,
432452 radius = radius ,
433453 gt_radius = gt_radius ,
434- range_search_metric_function = range_search_metric_function ,
454+ range_search_metric_function = range_search_metric_function ,
435455 gt_rsm = gt_rsm ,
436456 )[4 :],
437457 key_func = lambda parameters : index .get_range_search_name (
438458 search_parameters = parameters ,
439459 query_vectors = self .query_vectors ,
440460 radius = radius ,
441- ) + metric_key ,
461+ )
462+ + metric_key ,
442463 cost_metrics = ["time" ],
443464 perf_metrics = ["range_score_max_recall" ],
444465 results = results ,
445466 index = index ,
446467 )
447468
448469 def build_index_wrapper (self , index_desc : IndexDescriptor ):
449- if hasattr (index_desc , ' index' ):
470+ if hasattr (index_desc , " index" ):
450471 return
451472 if index_desc .factory is not None :
452473 training_vectors = copy (self .training_vectors )
453- training_vectors .num_vectors = index_desc .training_size
474+ if index_desc .training_size is not None :
475+ training_vectors .num_vectors = index_desc .training_size
454476 index = IndexFromFactory (
455477 num_threads = self .num_threads ,
456478 d = self .d ,
@@ -481,15 +503,24 @@ def clone_one(self, index_desc):
481503 training_vectors = self .training_vectors ,
482504 database_vectors = self .database_vectors ,
483505 query_vectors = self .query_vectors ,
484- index_descs = [self .get_index_desc ("Flat" ), index_desc ],
506+ index_descs = [self .get_index_desc ("Flat" ), index_desc ],
485507 range_ref_index_desc = self .range_ref_index_desc ,
486508 k = self .k ,
487509 distance_metric = self .distance_metric ,
488510 )
489- benchmark .set_io (self .io )
511+ benchmark .set_io (self .io . clone () )
490512 return benchmark
491513
492- def benchmark_one (self , dry_run , results : Dict [str , Any ], index_desc : IndexDescriptor , train , reconstruct , knn , range ):
514+ def benchmark_one (
515+ self ,
516+ dry_run ,
517+ results : Dict [str , Any ],
518+ index_desc : IndexDescriptor ,
519+ train ,
520+ reconstruct ,
521+ knn ,
522+ range ,
523+ ):
493524 faiss .omp_set_num_threads (self .num_threads )
494525 if not dry_run :
495526 self .knn_ground_truth ()
@@ -531,9 +562,12 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
531562 )
532563 assert requires is None
533564
534- if self .range_ref_index_desc is None or not index_desc .index .supports_range_search ():
565+ if (
566+ self .range_ref_index_desc is None
567+ or not index_desc .index .supports_range_search ()
568+ ):
535569 return results , None
536-
570+
537571 ref_index_desc = self .get_index_desc (self .range_ref_index_desc )
538572 if ref_index_desc is None :
539573 raise ValueError (
@@ -550,7 +584,9 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
550584 coefficients ,
551585 coefficients_training_data ,
552586 ) = self .range_search_reference (
553- ref_index_desc .index , ref_index_desc .search_params , range_metric
587+ ref_index_desc .index ,
588+ ref_index_desc .search_params ,
589+ range_metric ,
554590 )
555591 gt_rsm = self .range_ground_truth (
556592 gt_radius , range_search_metric_function
@@ -583,7 +619,15 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
583619
584620 return results , None
585621
586- def benchmark (self , result_file = None , local = False , train = False , reconstruct = False , knn = False , range = False ):
622+ def benchmark (
623+ self ,
624+ result_file = None ,
625+ local = False ,
626+ train = False ,
627+ reconstruct = False ,
628+ knn = False ,
629+ range = False ,
630+ ):
587631 logger .info ("begin evaluate" )
588632
589633 faiss .omp_set_num_threads (self .num_threads )
@@ -656,20 +700,34 @@ def benchmark(self, result_file=None, local=False, train=False, reconstruct=Fals
656700
657701 if current_todo :
658702 results_one = {"indices" : {}, "experiments" : {}}
659- params = [(self .clone_one (index_desc ), results_one , index_desc , train , reconstruct , knn , range ) for index_desc in current_todo ]
660- for result in self .io .launch_jobs (run_benchmark_one , params , local = local ):
703+ params = [
704+ (
705+ index_desc ,
706+ self .clone_one (index_desc ),
707+ results_one ,
708+ train ,
709+ reconstruct ,
710+ knn ,
711+ range ,
712+ )
713+ for index_desc in current_todo
714+ ]
715+ for result in self .io .launch_jobs (
716+ run_benchmark_one , params , local = local
717+ ):
661718 dict_merge (results , result )
662719
663- todo = next_todo
720+ todo = next_todo
664721
665722 if result_file is not None :
666723 self .io .write_json (results , result_file , overwrite = True )
667724 logger .info ("end evaluate" )
668725 return results
669726
727+
670728def run_benchmark_one (params ):
671729 logger .info (params )
672- benchmark , results , index_desc , train , reconstruct , knn , range = params
730+ index_desc , benchmark , results , train , reconstruct , knn , range = params
673731 results , requires = benchmark .benchmark_one (
674732 dry_run = False ,
675733 results = results ,
0 commit comments