From a1a889df28d14e002c7e80c90e32770d3b61068b Mon Sep 17 00:00:00 2001
From: Marco Wrzalik
Date: Fri, 2 Jul 2021 15:39:24 +0200
Subject: [PATCH 1/2] faster search_batch for ElasticsearchIndex due to threading

---
 src/datasets/search.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/datasets/search.py b/src/datasets/search.py
index 699bf9228fd..fa6508b3aaa 100644
--- a/src/datasets/search.py
+++ b/src/datasets/search.py
@@ -186,6 +186,19 @@ def search(self, query: str, k=10) -> SearchResults:
         hits = response["hits"]["hits"]
         return SearchResults([hit["_score"] for hit in hits], [int(hit["_id"]) for hit in hits])
 
+    def search_batch(self, queries, k: int = 10, max_workers=10) -> BatchedSearchResults:
+        import concurrent.futures
+        total_scores, total_indices = [None]*len(queries), [None]*len(queries)
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_index = {executor.submit(self.search, query, k): i
+                for i, query in enumerate(queries)}
+            for future in concurrent.futures.as_completed(future_to_index):
+                index = future_to_index[future]
+                results:SearchResults = future.result()
+                total_scores[index] = results.scores
+                total_indices[index] = results.indices
+        return BatchedSearchResults(total_indices=total_indices, total_scores=total_scores)
+
 
 class FaissIndex(BaseIndex):
     """

From 6e149a1e9f860243688796aed4ecebe1b574ab1e Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Mon, 12 Jul 2021 11:01:05 +0200
Subject: [PATCH 2/2] style

---
 src/datasets/search.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/datasets/search.py b/src/datasets/search.py
index fa6508b3aaa..d4cd28a9bfc 100644
--- a/src/datasets/search.py
+++ b/src/datasets/search.py
@@ -188,13 +188,13 @@ def search(self, query: str, k=10) -> SearchResults:
 
     def search_batch(self, queries, k: int = 10, max_workers=10) -> BatchedSearchResults:
         import concurrent.futures
-        total_scores, total_indices = [None]*len(queries), [None]*len(queries)
+
+        total_scores, total_indices = [None] * len(queries), [None] * len(queries)
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-            future_to_index = {executor.submit(self.search, query, k): i
-                for i, query in enumerate(queries)}
+            future_to_index = {executor.submit(self.search, query, k): i for i, query in enumerate(queries)}
             for future in concurrent.futures.as_completed(future_to_index):
                 index = future_to_index[future]
-                results:SearchResults = future.result()
+                results: SearchResults = future.result()
                 total_scores[index] = results.scores
                 total_indices[index] = results.indices
         return BatchedSearchResults(total_indices=total_indices, total_scores=total_scores)
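
For context (not part of the patches above): a minimal usage sketch of the batched Elasticsearch search path, assuming a local Elasticsearch server reachable at localhost:9200 and the existing Dataset.add_elasticsearch_index and Dataset.search_batch wrappers in datasets; the dataset name, split, column, and queries below are illustrative only.

    from datasets import load_dataset

    # Illustrative corpus; any dataset with a text column works the same way.
    ds = load_dataset("crime_and_punish", split="train[:1000]")

    # Index the "line" column in Elasticsearch
    # (assumes a server reachable at localhost:9200).
    ds.add_elasticsearch_index("line", host="localhost", port=9200)

    queries = ["many days", "some blood", "an axe"]

    # Dataset.search_batch dispatches to ElasticsearchIndex.search_batch,
    # which after patch 1/2 fans the per-query searches out to a thread pool.
    results = ds.search_batch("line", queries, k=10)
    print(results.total_indices[0], results.total_scores[0])

Note that max_workers is a parameter of ElasticsearchIndex.search_batch itself in this patch; whether the Dataset-level search_batch wrapper forwards it is outside the scope of the diff.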