diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index 8b1d65a505..212e643426 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -83,6 +83,9 @@ class DatasetDescriptor: embedding_column: Optional[str] = None + # only when the embedding column is a map + embedding_column_key: Optional[Any] = None + embedding_id_column: Optional[str] = None # unused in open-source @@ -106,6 +109,8 @@ class DatasetDescriptor: # desc_name desc_name: Optional[str] = None + normalize_L2: bool = False + def __hash__(self): return hash(self.get_filename()) diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index fe2fe103ef..b1252ad1b0 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -1138,6 +1138,8 @@ def assemble(self, dry_run): return None, None, "" logger.info(f"assemble, train {self.factory}") xt = self.io.get_dataset(self.training_vectors) + if self.training_vectors.normalize_L2: + faiss.normalize_L2(xt) _, t, _ = timer("train", lambda: codec.train(xt), once=True) t_aggregate += t