diff --git a/wbtools/literature/corpus.py b/wbtools/literature/corpus.py index 08563fb..2e719b8 100644 --- a/wbtools/literature/corpus.py +++ b/wbtools/literature/corpus.py @@ -174,7 +174,7 @@ def load(self, file_path: str) -> None: def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search: bool = False, remove_sections: List[PaperSections] = None, must_be_present: List[PaperSections] = None, path_to_model: str = None, - average_match: bool = True) -> List[SimilarityResult]: + average_match: bool = True, num_best: int = 10) -> List[SimilarityResult]: """query papers in the corpus by similarity with the provided query documents, which can be fulltext documents or sentences @@ -186,6 +186,7 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search: sections path_to_model (str): path to word2vec model average_match (bool): merge query documents and calculate average similarity to them + num_best (int): limit to the first n results by similarity score Returns: List[SimilarityResult]: list of papers most similar to the provided query documents @@ -198,12 +199,13 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search: split_sentences=sentence_search, remove_sections=remove_sections, must_be_present=must_be_present, lowercase=False, tokenize=False, remove_stopwords=False, remove_alpha=False) docsim_index, dictionary = get_softcosine_index(model=model, model_path=path_to_model, - corpus_list_token=corpus_list_token) + corpus_list_token=corpus_list_token, num_best=num_best) query_docs_preprocessed = [preprocess(doc=sentence, lower=True, tokenize=True, remove_stopwords=True, remove_alpha=True) for sentence in query_docs] sims = get_similar_documents(docsim_index, dictionary, query_docs_preprocessed, idx_paperid_map, average_match=average_match) - return [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx, - query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"", - query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx]) + - "\"") for sim in sims] + results = [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx, + query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"", + query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx] + ) + "\"") for sim in sims] + return results[0:num_best] if len(results) > num_best else results