@@ -256,9 +256,10 @@ def build_index_wrapper(self, codec_desc: CodecDescriptor):
256256 else :
257257 assert codec_desc .is_trained ()
258258
259- def train (
259+ def train_one (
260260 self , codec_desc : CodecDescriptor , results : Dict [str , Any ], dry_run = False
261261 ):
262+ faiss .omp_set_num_threads (codec_desc .num_threads )
262263 self .build_index_wrapper (codec_desc )
263264 if codec_desc .is_trained ():
264265 return results , None
@@ -274,6 +275,16 @@ def train(
274275 results ["indices" ][codec_desc .get_name ()] = meta
275276 return results , requires
276277
278+ def train (self , results , dry_run = False ):
279+ for desc in self .codec_descs :
280+ results , requires = self .train_one (desc , results , dry_run = dry_run )
281+ if dry_run :
282+ if requires is None :
283+ continue
284+ return results , requires
285+ assert requires is None
286+ return results , None
287+
277288
278289@dataclass
279290class BuildOperator (IndexOperator ):
@@ -322,17 +333,25 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
322333 else :
323334 assert index_desc .is_built ()
324335
325- def build (self , index_desc : IndexDescriptor , results : Dict [str , Any ]):
336+ def build_one (self , index_desc : IndexDescriptor , results : Dict [str , Any ]):
337+ faiss .omp_set_num_threads (index_desc .num_threads )
326338 self .build_index_wrapper (index_desc )
327339 if index_desc .is_built ():
328340 return
329341 index_desc .index .get_index ()
330342
343+ def build (self , results : Dict [str , Any ]):
344+ # TODO: add support for dry_run
345+ for index_desc in self .index_descs :
346+ self .build_one (index_desc , results )
347+ return results , None
348+
331349
332350@dataclass
333351class SearchOperator (IndexOperator ):
334352 knn_descs : List [KnnDescriptor ] = field (default_factory = lambda : [])
335353 range : bool = False
354+ compute_gt : bool = True
336355
337356 def get_desc (self , name : str ) -> Optional [KnnDescriptor ]:
338357 for desc in self .knn_descs :
@@ -655,85 +674,16 @@ def range_search_benchmark(
655674 index = index ,
656675 )
657676
658-
659- @dataclass
660- class ExecutionOperator :
661- distance_metric : str = "L2"
662- num_threads : int = 1
663- train_op : Optional [TrainOperator ] = None
664- build_op : Optional [BuildOperator ] = None
665- search_op : Optional [SearchOperator ] = None
666- compute_gt : bool = True
667-
668- def __post_init__ (self ):
669- if self .distance_metric == "IP" :
670- self .distance_metric_type = faiss .METRIC_INNER_PRODUCT
671- elif self .distance_metric == "L2" :
672- self .distance_metric_type = faiss .METRIC_L2
673- else :
674- raise ValueError
675-
676- def set_io (self , io : BenchmarkIO ):
677- self .io = io
678- self .io .distance_metric = self .distance_metric
679- self .io .distance_metric_type = self .distance_metric_type
680- if self .train_op :
681- self .train_op .set_io (io )
682- if self .build_op :
683- self .build_op .set_io (io )
684- if self .search_op :
685- self .search_op .set_io (io )
686-
687- def train_one (self , codec_desc : CodecDescriptor , results : Dict [str , Any ], dry_run ):
688- faiss .omp_set_num_threads (self .num_threads )
689- assert self .train_op is not None
690- self .train_op .train (codec_desc , results , dry_run )
691-
692- def train (self , results , dry_run = False ):
693- faiss .omp_set_num_threads (self .num_threads )
694- if self .train_op is None :
695- return
696-
697- for codec_desc in self .train_op .codec_descs :
698- self .train_one (codec_desc , results , dry_run )
699-
700- def build_one (self , results : Dict [str , Any ], index_desc : IndexDescriptor ):
701- faiss .omp_set_num_threads (self .num_threads )
702- assert self .build_op is not None
703- self .build_op .build (index_desc , results )
704-
705- def build (self , results : Dict [str , Any ]):
706- faiss .omp_set_num_threads (self .num_threads )
707- if self .build_op is None :
708- return
709-
710- for index_desc in self .build_op .index_descs :
711- self .build_one (index_desc , results )
712-
713- def search (self ):
714- faiss .omp_set_num_threads (self .num_threads )
715- if self .search_op is None :
716- return
717-
718- for index_desc in self .search_op .knn_descs :
719- self .search_one (index_desc )
720-
721677 def search_one (
722678 self ,
723679 knn_desc : KnnDescriptor ,
724680 results : Dict [str , Any ],
725681 dry_run = False ,
726682 range = False ,
727683 ):
728- faiss .omp_set_num_threads (self .num_threads )
729- assert self .search_op is not None
730-
731- if not dry_run and self .compute_gt :
732- self .create_gt_knn (knn_desc )
733- self .create_range_ref_knn (knn_desc )
734-
735- self .search_op .build_index_wrapper (knn_desc )
684+ faiss .omp_set_num_threads (knn_desc .num_threads )
736685
686+ self .build_index_wrapper (knn_desc )
737687 # results, requires = self.reconstruct_benchmark(
738688 # dry_run=True,
739689 # results=results,
@@ -749,7 +699,7 @@ def search_one(
749699 # index=index_desc.index,
750700 # )
751701 # assert requires is None
752- results , requires = self .search_op . knn_search_benchmark (
702+ results , requires = self .knn_search_benchmark (
753703 dry_run = True ,
754704 results = results ,
755705 knn_desc = knn_desc ,
@@ -758,7 +708,7 @@ def search_one(
758708 if dry_run :
759709 return results , requires
760710 else :
761- results , requires = self .search_op . knn_search_benchmark (
711+ results , requires = self .knn_search_benchmark (
762712 dry_run = False ,
763713 results = results ,
764714 knn_desc = knn_desc ,
@@ -771,7 +721,7 @@ def search_one(
771721 ):
772722 return results , None
773723
774- ref_index_desc = self .search_op . get_desc (knn_desc .range_ref_index_desc )
724+ ref_index_desc = self .get_desc (knn_desc .range_ref_index_desc )
775725 if ref_index_desc is None :
776726 raise ValueError (
777727 f"{ knn_desc .get_name ()} : Unknown range index { knn_desc .range_ref_index_desc } "
@@ -786,17 +736,18 @@ def search_one(
786736 range_search_metric_function ,
787737 coefficients ,
788738 coefficients_training_data ,
789- ) = self .search_op . range_search_reference (
739+ ) = self .range_search_reference (
790740 ref_index_desc .index ,
791741 ref_index_desc .search_params ,
792742 range_metric ,
743+ query_dataset = knn_desc .query_dataset ,
793744 )
794745 gt_rsm = None
795746 if self .compute_gt :
796- gt_rsm = self .search_op . range_ground_truth (
747+ gt_rsm = self .range_ground_truth (
797748 gt_radius , range_search_metric_function
798749 )
799- results , requires = self .search_op . range_search_benchmark (
750+ results , requires = self .range_search_benchmark (
800751 dry_run = True ,
801752 results = results ,
802753 index = knn_desc .index ,
@@ -805,13 +756,13 @@ def search_one(
805756 gt_radius = gt_radius ,
806757 range_search_metric_function = range_search_metric_function ,
807758 gt_rsm = gt_rsm ,
808- query_vectors = knn_desc .query_dataset ,
759+ query_dataset = knn_desc .query_dataset ,
809760 )
810761 if range and requires is not None :
811762 if dry_run :
812763 return results , requires
813764 else :
814- results , requires = self .search_op . range_search_benchmark (
765+ results , requires = self .range_search_benchmark (
815766 dry_run = False ,
816767 results = results ,
817768 index = knn_desc .index ,
@@ -820,12 +771,62 @@ def search_one(
820771 gt_radius = gt_radius ,
821772 range_search_metric_function = range_search_metric_function ,
822773 gt_rsm = gt_rsm ,
823- query_vectors = knn_desc .query_dataset ,
774+ query_dataset = knn_desc .query_dataset ,
824775 )
825776 assert requires is None
826777
827778 return results , None
828779
780+ def search (
781+ self ,
782+ results : Dict [str , Any ],
783+ dry_run : bool = False ,):
784+ for knn_desc in self .knn_descs :
785+ results , requires = self .search_one (
786+ knn_desc = knn_desc ,
787+ results = results ,
788+ dry_run = dry_run ,
789+ range = self .range )
790+ if dry_run :
791+ if requires is None :
792+ continue
793+ return results , requires
794+
795+ assert requires is None
796+ return results , None
797+
798+
799+ @dataclass
800+ class ExecutionOperator :
801+ distance_metric : str = "L2"
802+ num_threads : int = 1
803+ train_op : Optional [TrainOperator ] = None
804+ build_op : Optional [BuildOperator ] = None
805+ search_op : Optional [SearchOperator ] = None
806+ compute_gt : bool = True
807+
808+ def __post_init__ (self ):
809+ if self .distance_metric == "IP" :
810+ self .distance_metric_type = faiss .METRIC_INNER_PRODUCT
811+ elif self .distance_metric == "L2" :
812+ self .distance_metric_type = faiss .METRIC_L2
813+ else :
814+ raise ValueError
815+
816+ if self .search_op is not None :
817+ self .search_op .compute_gt = self .compute_gt
818+
819+ def set_io (self , io : BenchmarkIO ):
820+ self .io = io
821+ self .io .distance_metric = self .distance_metric
822+ self .io .distance_metric_type = self .distance_metric_type
823+ if self .train_op :
824+ self .train_op .set_io (io )
825+ if self .build_op :
826+ self .build_op .set_io (io )
827+ if self .search_op :
828+ self .search_op .set_io (io )
829+
829830 def create_gt_codec (
830831 self , codec_desc , results , train = True
831832 ) -> Optional [CodecDescriptor ]:
@@ -841,7 +842,7 @@ def create_gt_codec(
841842 )
842843 self .train_op .codec_descs .insert (0 , gt_codec_desc )
843844 if train :
844- self .train_op .train (gt_codec_desc , results , dry_run = False )
845+ self .train_op .train_one (gt_codec_desc , results , dry_run = False )
845846
846847 return gt_codec_desc
847848
@@ -865,7 +866,7 @@ def create_gt_index(
865866 )
866867 self .build_op .index_descs .insert (0 , gt_index_desc )
867868 if build :
868- self .build_op .build (gt_index_desc , results )
869+ self .build_op .build_one (gt_index_desc , results )
869870
870871 return gt_index_desc
871872
@@ -906,7 +907,9 @@ def create_range_ref_knn(self, knn_desc):
906907 return
907908
908909 if knn_desc .range_ref_index_desc is not None :
909- ref_index_desc = self .get_desc (knn_desc .range_ref_index_desc )
910+ ref_index_desc = (
911+ self .search_op .get_desc (knn_desc .range_ref_index_desc )
912+ )
910913 if ref_index_desc is None :
911914 raise ValueError (f"Unknown range index { knn_desc .range_ref_index_desc } " )
912915 if ref_index_desc .range_metrics is None :
@@ -921,19 +924,20 @@ def create_range_ref_knn(self, knn_desc):
921924 range_search_metric_function ,
922925 coefficients ,
923926 coefficients_training_data ,
924- ) = self .range_search_reference (
927+ ) = self .search_op . range_search_reference (
925928 knn_desc .index , knn_desc .search_params , range_metric
926929 )
927930 results ["metrics" ][metric_key ] = {
928931 "coefficients" : coefficients ,
929932 "training_data" : coefficients_training_data ,
930933 }
931- knn_desc .gt_rsm = self .range_ground_truth (
934+ knn_desc .gt_rsm = self .search_op . range_ground_truth (
932935 knn_desc .gt_radius , range_search_metric_function
933936 )
934937
935938 def create_ground_truths (self , results : Dict [str , Any ]):
936- # TODO: Create all ground truth descriptors and put them in index descriptor as reference
939+ # TODO: Create all ground truth descriptors and
940+ # put them in index descriptor as reference
937941 if self .train_op is not None :
938942 for codec_desc in self .train_op .codec_descs :
939943 self .create_gt_codec (codec_desc , results )
@@ -949,33 +953,33 @@ def create_ground_truths(self, results: Dict[str, Any]):
949953 self .create_gt_knn (knn_desc , results )
950954 self .create_range_ref_knn (knn_desc )
951955
952- def execute (self , results : Dict [str , Any ], dry_run : False ):
956+ def prepare_gt_or_range_knn (self , results : Dict [str , Any ]):
957+ if self .search_op is not None :
958+ for knn_desc in self .search_op .knn_descs :
959+ self .create_gt_knn (knn_desc , results )
960+ self .create_range_ref_knn (knn_desc )
961+
962+ def execute (self , results : Dict [str , Any ], dry_run : bool = False ):
963+ faiss .omp_set_num_threads (self .num_threads )
953964 if self .train_op is not None :
954- for desc in self .train_op .codec_descs :
955- results , requires = self .train_op .train (desc , results , dry_run = dry_run )
956- if dry_run :
957- if requires is None :
958- continue
959- return results , requires
960- assert requires is None
965+ results , requires = (
966+ self .train_op .train (results = results , dry_run = dry_run )
967+ )
968+ if dry_run and requires :
969+ return results , requires
961970
962971 if self .build_op is not None :
963- for desc in self .build_op .index_descs :
964- self . build_op . build ( desc , results )
972+ self .build_op .build ( results )
973+
965974 if self .search_op is not None :
966- for desc in self .search_op .knn_descs :
967- results , requires = self .search_one (
968- knn_desc = desc ,
969- results = results ,
970- dry_run = dry_run ,
971- range = self .search_op .range ,
972- )
973- if dry_run :
974- if requires is None :
975- continue
976- return results , requires
975+ if not dry_run and self .compute_gt :
976+ self .prepare_gt_or_range_knn (results )
977977
978- assert requires is None
978+ results , requires = (
979+ self .search_op .search (results = results , dry_run = dry_run )
980+ )
981+ if dry_run and requires :
982+ return results , requires
979983 return results , None
980984
981985 def execute_2 (self , result_file = None ):
0 commit comments