def search_batch(self, queries, k: int = 10, max_workers=10) -> BatchedSearchResults:
    """Run `search` for every query concurrently and collect the results.

    Each query is passed to `self.search(query, k)` on a thread pool, so the
    individual searches can overlap instead of running one after another.

    Args:
        queries: Iterable of queries, each forwarded unchanged to
            `self.search`. Any iterable works (list, tuple, generator, ...);
            it does not need to support `len()`.
        k (`int`): Number of hits to request per query, forwarded to
            `self.search`.
        max_workers: Maximum number of worker threads, passed to
            `concurrent.futures.ThreadPoolExecutor`.

    Returns:
        `BatchedSearchResults`: `total_scores[i]` and `total_indices[i]` hold
        the `scores` and `indices` returned by `self.search` for the i-th
        query, i.e. results are in the same order as `queries`.

    Raises:
        Any exception raised by `self.search` is propagated to the caller.
        When several queries fail, the failure of the earliest query (in
        input order) is the one that is raised.
    """
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # `Executor.map` yields results in submission order, so no manual
        # future -> position bookkeeping (and no `len(queries)`) is needed.
        results = list(executor.map(lambda query: self.search(query, k), queries))

    total_scores = [result.scores for result in results]
    total_indices = [result.indices for result in results]
    return BatchedSearchResults(total_indices=total_indices, total_scores=total_scores)