Skip to content

Commit 61eaf19

Browse files
kuarorafacebook-github-bot
authored andcommitted
Move train, build and search to their respective operators (#3934)
Summary: Pull Request resolved: #3934 Initial thought was to be able to call individual operations on execution operator but it make sense to keep single interface 'execute' and move all these implementations to respective operators. Reviewed By: satymish Differential Revision: D63290104 fbshipit-source-id: d1f0b1391c38552c5cdb0a8ea935e23d0d0cb75b
1 parent d243e62 commit 61eaf19

1 file changed

Lines changed: 114 additions & 110 deletions

File tree

benchs/bench_fw/benchmark.py

Lines changed: 114 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,10 @@ def build_index_wrapper(self, codec_desc: CodecDescriptor):
256256
else:
257257
assert codec_desc.is_trained()
258258

259-
def train(
259+
def train_one(
260260
self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run=False
261261
):
262+
faiss.omp_set_num_threads(codec_desc.num_threads)
262263
self.build_index_wrapper(codec_desc)
263264
if codec_desc.is_trained():
264265
return results, None
@@ -274,6 +275,16 @@ def train(
274275
results["indices"][codec_desc.get_name()] = meta
275276
return results, requires
276277

278+
def train(self, results, dry_run=False):
279+
for desc in self.codec_descs:
280+
results, requires = self.train_one(desc, results, dry_run=dry_run)
281+
if dry_run:
282+
if requires is None:
283+
continue
284+
return results, requires
285+
assert requires is None
286+
return results, None
287+
277288

278289
@dataclass
279290
class BuildOperator(IndexOperator):
@@ -322,17 +333,25 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
322333
else:
323334
assert index_desc.is_built()
324335

325-
def build(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
336+
def build_one(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
337+
faiss.omp_set_num_threads(index_desc.num_threads)
326338
self.build_index_wrapper(index_desc)
327339
if index_desc.is_built():
328340
return
329341
index_desc.index.get_index()
330342

343+
def build(self, results: Dict[str, Any]):
344+
# TODO: add support for dry_run
345+
for index_desc in self.index_descs:
346+
self.build_one(index_desc, results)
347+
return results, None
348+
331349

332350
@dataclass
333351
class SearchOperator(IndexOperator):
334352
knn_descs: List[KnnDescriptor] = field(default_factory=lambda: [])
335353
range: bool = False
354+
compute_gt: bool = True
336355

337356
def get_desc(self, name: str) -> Optional[KnnDescriptor]:
338357
for desc in self.knn_descs:
@@ -655,85 +674,16 @@ def range_search_benchmark(
655674
index=index,
656675
)
657676

658-
659-
@dataclass
660-
class ExecutionOperator:
661-
distance_metric: str = "L2"
662-
num_threads: int = 1
663-
train_op: Optional[TrainOperator] = None
664-
build_op: Optional[BuildOperator] = None
665-
search_op: Optional[SearchOperator] = None
666-
compute_gt: bool = True
667-
668-
def __post_init__(self):
669-
if self.distance_metric == "IP":
670-
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
671-
elif self.distance_metric == "L2":
672-
self.distance_metric_type = faiss.METRIC_L2
673-
else:
674-
raise ValueError
675-
676-
def set_io(self, io: BenchmarkIO):
677-
self.io = io
678-
self.io.distance_metric = self.distance_metric
679-
self.io.distance_metric_type = self.distance_metric_type
680-
if self.train_op:
681-
self.train_op.set_io(io)
682-
if self.build_op:
683-
self.build_op.set_io(io)
684-
if self.search_op:
685-
self.search_op.set_io(io)
686-
687-
def train_one(self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run):
688-
faiss.omp_set_num_threads(self.num_threads)
689-
assert self.train_op is not None
690-
self.train_op.train(codec_desc, results, dry_run)
691-
692-
def train(self, results, dry_run=False):
693-
faiss.omp_set_num_threads(self.num_threads)
694-
if self.train_op is None:
695-
return
696-
697-
for codec_desc in self.train_op.codec_descs:
698-
self.train_one(codec_desc, results, dry_run)
699-
700-
def build_one(self, results: Dict[str, Any], index_desc: IndexDescriptor):
701-
faiss.omp_set_num_threads(self.num_threads)
702-
assert self.build_op is not None
703-
self.build_op.build(index_desc, results)
704-
705-
def build(self, results: Dict[str, Any]):
706-
faiss.omp_set_num_threads(self.num_threads)
707-
if self.build_op is None:
708-
return
709-
710-
for index_desc in self.build_op.index_descs:
711-
self.build_one(index_desc, results)
712-
713-
def search(self):
714-
faiss.omp_set_num_threads(self.num_threads)
715-
if self.search_op is None:
716-
return
717-
718-
for index_desc in self.search_op.knn_descs:
719-
self.search_one(index_desc)
720-
721677
def search_one(
722678
self,
723679
knn_desc: KnnDescriptor,
724680
results: Dict[str, Any],
725681
dry_run=False,
726682
range=False,
727683
):
728-
faiss.omp_set_num_threads(self.num_threads)
729-
assert self.search_op is not None
730-
731-
if not dry_run and self.compute_gt:
732-
self.create_gt_knn(knn_desc)
733-
self.create_range_ref_knn(knn_desc)
734-
735-
self.search_op.build_index_wrapper(knn_desc)
684+
faiss.omp_set_num_threads(knn_desc.num_threads)
736685

686+
self.build_index_wrapper(knn_desc)
737687
# results, requires = self.reconstruct_benchmark(
738688
# dry_run=True,
739689
# results=results,
@@ -749,7 +699,7 @@ def search_one(
749699
# index=index_desc.index,
750700
# )
751701
# assert requires is None
752-
results, requires = self.search_op.knn_search_benchmark(
702+
results, requires = self.knn_search_benchmark(
753703
dry_run=True,
754704
results=results,
755705
knn_desc=knn_desc,
@@ -758,7 +708,7 @@ def search_one(
758708
if dry_run:
759709
return results, requires
760710
else:
761-
results, requires = self.search_op.knn_search_benchmark(
711+
results, requires = self.knn_search_benchmark(
762712
dry_run=False,
763713
results=results,
764714
knn_desc=knn_desc,
@@ -771,7 +721,7 @@ def search_one(
771721
):
772722
return results, None
773723

774-
ref_index_desc = self.search_op.get_desc(knn_desc.range_ref_index_desc)
724+
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
775725
if ref_index_desc is None:
776726
raise ValueError(
777727
f"{knn_desc.get_name()}: Unknown range index {knn_desc.range_ref_index_desc}"
@@ -786,17 +736,18 @@ def search_one(
786736
range_search_metric_function,
787737
coefficients,
788738
coefficients_training_data,
789-
) = self.search_op.range_search_reference(
739+
) = self.range_search_reference(
790740
ref_index_desc.index,
791741
ref_index_desc.search_params,
792742
range_metric,
743+
query_dataset=knn_desc.query_dataset,
793744
)
794745
gt_rsm = None
795746
if self.compute_gt:
796-
gt_rsm = self.search_op.range_ground_truth(
747+
gt_rsm = self.range_ground_truth(
797748
gt_radius, range_search_metric_function
798749
)
799-
results, requires = self.search_op.range_search_benchmark(
750+
results, requires = self.range_search_benchmark(
800751
dry_run=True,
801752
results=results,
802753
index=knn_desc.index,
@@ -805,13 +756,13 @@ def search_one(
805756
gt_radius=gt_radius,
806757
range_search_metric_function=range_search_metric_function,
807758
gt_rsm=gt_rsm,
808-
query_vectors=knn_desc.query_dataset,
759+
query_dataset=knn_desc.query_dataset,
809760
)
810761
if range and requires is not None:
811762
if dry_run:
812763
return results, requires
813764
else:
814-
results, requires = self.search_op.range_search_benchmark(
765+
results, requires = self.range_search_benchmark(
815766
dry_run=False,
816767
results=results,
817768
index=knn_desc.index,
@@ -820,12 +771,62 @@ def search_one(
820771
gt_radius=gt_radius,
821772
range_search_metric_function=range_search_metric_function,
822773
gt_rsm=gt_rsm,
823-
query_vectors=knn_desc.query_dataset,
774+
query_dataset=knn_desc.query_dataset,
824775
)
825776
assert requires is None
826777

827778
return results, None
828779

780+
def search(
781+
self,
782+
results: Dict[str, Any],
783+
dry_run: bool = False,):
784+
for knn_desc in self.knn_descs:
785+
results, requires = self.search_one(
786+
knn_desc=knn_desc,
787+
results=results,
788+
dry_run=dry_run,
789+
range=self.range)
790+
if dry_run:
791+
if requires is None:
792+
continue
793+
return results, requires
794+
795+
assert requires is None
796+
return results, None
797+
798+
799+
@dataclass
800+
class ExecutionOperator:
801+
distance_metric: str = "L2"
802+
num_threads: int = 1
803+
train_op: Optional[TrainOperator] = None
804+
build_op: Optional[BuildOperator] = None
805+
search_op: Optional[SearchOperator] = None
806+
compute_gt: bool = True
807+
808+
def __post_init__(self):
809+
if self.distance_metric == "IP":
810+
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
811+
elif self.distance_metric == "L2":
812+
self.distance_metric_type = faiss.METRIC_L2
813+
else:
814+
raise ValueError
815+
816+
if self.search_op is not None:
817+
self.search_op.compute_gt = self.compute_gt
818+
819+
def set_io(self, io: BenchmarkIO):
820+
self.io = io
821+
self.io.distance_metric = self.distance_metric
822+
self.io.distance_metric_type = self.distance_metric_type
823+
if self.train_op:
824+
self.train_op.set_io(io)
825+
if self.build_op:
826+
self.build_op.set_io(io)
827+
if self.search_op:
828+
self.search_op.set_io(io)
829+
829830
def create_gt_codec(
830831
self, codec_desc, results, train=True
831832
) -> Optional[CodecDescriptor]:
@@ -841,7 +842,7 @@ def create_gt_codec(
841842
)
842843
self.train_op.codec_descs.insert(0, gt_codec_desc)
843844
if train:
844-
self.train_op.train(gt_codec_desc, results, dry_run=False)
845+
self.train_op.train_one(gt_codec_desc, results, dry_run=False)
845846

846847
return gt_codec_desc
847848

@@ -865,7 +866,7 @@ def create_gt_index(
865866
)
866867
self.build_op.index_descs.insert(0, gt_index_desc)
867868
if build:
868-
self.build_op.build(gt_index_desc, results)
869+
self.build_op.build_one(gt_index_desc, results)
869870

870871
return gt_index_desc
871872

@@ -906,7 +907,9 @@ def create_range_ref_knn(self, knn_desc):
906907
return
907908

908909
if knn_desc.range_ref_index_desc is not None:
909-
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
910+
ref_index_desc = (
911+
self.search_op.get_desc(knn_desc.range_ref_index_desc)
912+
)
910913
if ref_index_desc is None:
911914
raise ValueError(f"Unknown range index {knn_desc.range_ref_index_desc}")
912915
if ref_index_desc.range_metrics is None:
@@ -921,19 +924,20 @@ def create_range_ref_knn(self, knn_desc):
921924
range_search_metric_function,
922925
coefficients,
923926
coefficients_training_data,
924-
) = self.range_search_reference(
927+
) = self.search_op.range_search_reference(
925928
knn_desc.index, knn_desc.search_params, range_metric
926929
)
927930
results["metrics"][metric_key] = {
928931
"coefficients": coefficients,
929932
"training_data": coefficients_training_data,
930933
}
931-
knn_desc.gt_rsm = self.range_ground_truth(
934+
knn_desc.gt_rsm = self.search_op.range_ground_truth(
932935
knn_desc.gt_radius, range_search_metric_function
933936
)
934937

935938
def create_ground_truths(self, results: Dict[str, Any]):
936-
# TODO: Create all ground truth descriptors and put them in index descriptor as reference
939+
# TODO: Create all ground truth descriptors and
940+
# put them in index descriptor as reference
937941
if self.train_op is not None:
938942
for codec_desc in self.train_op.codec_descs:
939943
self.create_gt_codec(codec_desc, results)
@@ -949,33 +953,33 @@ def create_ground_truths(self, results: Dict[str, Any]):
949953
self.create_gt_knn(knn_desc, results)
950954
self.create_range_ref_knn(knn_desc)
951955

952-
def execute(self, results: Dict[str, Any], dry_run: False):
956+
def prepare_gt_or_range_knn(self, results: Dict[str, Any]):
957+
if self.search_op is not None:
958+
for knn_desc in self.search_op.knn_descs:
959+
self.create_gt_knn(knn_desc, results)
960+
self.create_range_ref_knn(knn_desc)
961+
962+
def execute(self, results: Dict[str, Any], dry_run: bool = False):
963+
faiss.omp_set_num_threads(self.num_threads)
953964
if self.train_op is not None:
954-
for desc in self.train_op.codec_descs:
955-
results, requires = self.train_op.train(desc, results, dry_run=dry_run)
956-
if dry_run:
957-
if requires is None:
958-
continue
959-
return results, requires
960-
assert requires is None
965+
results, requires = (
966+
self.train_op.train(results=results, dry_run=dry_run)
967+
)
968+
if dry_run and requires:
969+
return results, requires
961970

962971
if self.build_op is not None:
963-
for desc in self.build_op.index_descs:
964-
self.build_op.build(desc, results)
972+
self.build_op.build(results)
973+
965974
if self.search_op is not None:
966-
for desc in self.search_op.knn_descs:
967-
results, requires = self.search_one(
968-
knn_desc=desc,
969-
results=results,
970-
dry_run=dry_run,
971-
range=self.search_op.range,
972-
)
973-
if dry_run:
974-
if requires is None:
975-
continue
976-
return results, requires
975+
if not dry_run and self.compute_gt:
976+
self.prepare_gt_or_range_knn(results)
977977

978-
assert requires is None
978+
results, requires = (
979+
self.search_op.search(results=results, dry_run=dry_run)
980+
)
981+
if dry_run and requires:
982+
return results, requires
979983
return results, None
980984

981985
def execute_2(self, result_file=None):

0 commit comments

Comments
 (0)