From 6bcc00b006d8e3c01f3b25cad1e72b46b5b7d8d5 Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 31 Jan 2023 12:13:51 +0900 Subject: [PATCH 1/2] support batch search --- tinysearch/tinysearch.py | 54 ++++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/tinysearch/tinysearch.py b/tinysearch/tinysearch.py index bbaf035..174b63a 100644 --- a/tinysearch/tinysearch.py +++ b/tinysearch/tinysearch.py @@ -52,19 +52,59 @@ def search( ) -> List[Tuple[Document, float]]: ... + @overload def search( self, - query: str, + query: List[str], + *, + topk: Optional[int] = ..., + ) -> List[List[Document]]: + ... + + @overload + def search( + self, + query: List[str], + *, + return_scores: Literal[False], + topk: Optional[int] = ..., + ) -> List[List[Document]]: + ... + + @overload + def search( + self, + query: List[str], + *, + return_scores: Literal[True], + topk: Optional[int] = ..., + ) -> List[List[Tuple[Document, float]]]: + ... + + def search( + self, + query: Union[str, List[str]], *, return_scores: bool = False, topk: Optional[int] = 10, - ) -> Union[List[Document], List[Tuple[Document, float]]]: - tokens = self.analyzer(query) - query_vector = self.vectorizer.vectorize_queries([tokens]) - results = self.indexer.search(query_vector, topk=topk)[0] + ) -> Union[List[Document], List[Tuple[Document, float]], List[List[Document]], List[List[Tuple[Document, float]]]]: + return_as_batch = True + if isinstance(query, str): + query = [query] + return_as_batch = False + + batched_tokens = [self.analyzer(q) for q in query] + query_vector = self.vectorizer.vectorize_queries(batched_tokens) + results = self.indexer.search(query_vector, topk=topk) + + output: Union[List[List[Document]], List[List[Tuple[Document, float]]]] if return_scores: - return [(self.storage[id_], score) for id_, score in results] - return [self.storage[id_] for id_, _ in results] + output = [[(self.storage[id_], score) for id_, score in result] for result in results] + else: + output = [[self.storage[id_] for id_, _ in result] for result in results] + if return_as_batch: + return output + return output[0] def save(self, filename: Union[str, PathLike]) -> None: with open(filename, "wb") as pklfile: From 465406c318d76623fc6ec827b90e7037016d903c Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 31 Jan 2023 12:14:14 +0900 Subject: [PATCH 2/2] add batch option --- examples/quora/evaluate.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/quora/evaluate.py b/examples/quora/evaluate.py index 3f4502f..d14018b 100644 --- a/examples/quora/evaluate.py +++ b/examples/quora/evaluate.py @@ -17,6 +17,7 @@ def main() -> None: parser.add_argument("tinysearch_filename", type=Path) parser.add_argument("--subset", choices=["dev", "test"], default="dev") parser.add_argument("--topk", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=1) args = parser.parse_args() dataset_reader = DataLoader(f"beir/quora/{args.subset}") @@ -26,16 +27,20 @@ def main() -> None: relations = dataset_reader.load_relations() golds = dataset_reader.load_golds() + num_done = 0 elapsed_time = 0.0 - for i, query in enumerate(dataset_reader.load_query(), start=1): + for batch in tinysearch.util.batched(dataset_reader.load_query(), args.batch_size): + queries = [query["text"] for query in batch] start_time = time.time() - search_results = searcher.search(query["text"], topk=args.topk) + search_results = searcher.search(queries, topk=args.topk) elapsed_time += time.time() - start_time - gold = golds[query["id"]] - pred = [(doc["id"], relations[(query["id"], doc["id"])]) for doc in search_results] - metrics(gold, pred) - metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.get_metrics().items()) - print(f"\r{100 * i/len(golds):6.2f}% speed={i/elapsed_time:.4f}qs/s {metrics_str}", end="") + for query, result in zip(batch, search_results): + gold = golds[query["id"]] + pred = [(doc["id"], relations[(query["id"], doc["id"])]) for doc in result] + metrics(gold, pred) + num_done += 1 + metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.get_metrics().items()) + print(f"\r{100 * num_done/len(golds):6.2f}% speed={num_done/elapsed_time:.4f}qs/s {metrics_str}", end="") print()