Skip to content

Commit

Permalink
fix(similarity_query): new param to set max num of results
Browse files Browse the repository at this point in the history
  • Loading branch information
valearna committed Dec 14, 2020
1 parent fc2814c commit e2bbdba
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions wbtools/literature/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def load(self, file_path: str) -> None:
def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search: bool = False,
remove_sections: List[PaperSections] = None,
must_be_present: List[PaperSections] = None, path_to_model: str = None,
average_match: bool = True) -> List[SimilarityResult]:
average_match: bool = True, num_best: int = 10) -> List[SimilarityResult]:
"""query papers in the corpus by similarity with the provided query documents, which can be fulltext documents
or sentences
Expand All @@ -186,6 +186,7 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
sections
path_to_model (str): path to word2vec model
average_match (bool): merge query documents and calculate average similarity to them
num_best (int): limit to the first n results by similarity score
Returns:
List[SimilarityResult]: list of papers most similar to the provided query documents
Expand All @@ -198,12 +199,13 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
split_sentences=sentence_search, remove_sections=remove_sections, must_be_present=must_be_present,
lowercase=False, tokenize=False, remove_stopwords=False, remove_alpha=False)
docsim_index, dictionary = get_softcosine_index(model=model, model_path=path_to_model,
corpus_list_token=corpus_list_token)
corpus_list_token=corpus_list_token, num_best=num_best)
query_docs_preprocessed = [preprocess(doc=sentence, lower=True, tokenize=True, remove_stopwords=True,
remove_alpha=True) for sentence in query_docs]
sims = get_similar_documents(docsim_index, dictionary, query_docs_preprocessed, idx_paperid_map,
average_match=average_match)
return [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx,
query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"",
query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx]) +
"\"") for sim in sims]
results = [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx,
query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"",
query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx]
) + "\"") for sim in sims]
return results[0:num_best] if len(results) > num_best else results

0 comments on commit e2bbdba

Please # to comment.