-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy path: retrieval.py
167 lines (137 loc) · 7.65 KB
/
retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import logging
from abc import abstractmethod
from typing import List, Dict, Any
import pandas as pd
from scholarqa.rag.reranker.reranker_base import AbstractReranker
from scholarqa.rag.retriever_base import AbstractRetriever
from scholarqa.utils import make_int, get_ref_author_str
logger = logging.getLogger(__name__)
class AbsPaperFinder(AbstractRetriever):
    """Retriever interface extended with a reranking hook over retrieved contexts."""

    @abstractmethod
    def rerank(self, query: str, retrieved_ctxs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Reorder the retrieved contexts by relevance to *query* and return them."""
        ...
class PaperFinder(AbsPaperFinder):
    """Default paper finder: delegates retrieval to an underlying retriever and
    performs no reranking (contexts keep their retrieval order).

    Also provides paper-level aggregation of passage-level snippets into a
    pandas DataFrame for downstream consumption.
    """

    def __init__(self, retriever: AbstractRetriever,
                 context_threshold: float = 0.0):
        # retriever: backend used for passage/paper retrieval.
        # context_threshold: papers whose aggregated relevance score is <= this
        # value are dropped in format_retrieval_response.
        self.retriever = retriever
        self.context_threshold = context_threshold
        self.n_rerank = -1  # no cap on passage count; subclasses may override

    def retrieve_passages(self, query: str, **filter_kwargs) -> List[Dict[str, Any]]:
        """Retrieve relevant passages along with scores from an index for the given query"""
        return self.retriever.retrieve_passages(query, **filter_kwargs)

    def retrieve_additional_papers(self, query: str, **filter_kwargs) -> List[Dict[str, Any]]:
        """Fetch additional candidate papers from the underlying retriever."""
        return self.retriever.retrieve_additional_papers(query, **filter_kwargs)

    def rerank(self, query: str, retrieved_ctxs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """No-op rerank: return the contexts unchanged, in retrieval order."""
        return retrieved_ctxs

    def aggregate_into_dataframe(self, snippets_list: List[Dict[str, Any]], paper_metadata: Dict[str, Any]) -> \
            pd.DataFrame:
        """The reranked snippets is passage level. This function aggregates the passages to the paper level,
        The Dataframe also consists of aggregated passages stitched together with the paper title and abstract in the markdown format."""
        # Drop snippets without metadata or without text before aggregating.
        snippets_list = [snippet for snippet in snippets_list if snippet["corpus_id"] in paper_metadata
                         and snippet["text"] is not None]
        aggregated_candidates = self.aggregate_snippets_to_papers(snippets_list, paper_metadata)
        return self.format_retrieval_response(aggregated_candidates)

    @staticmethod
    def aggregate_snippets_to_papers(snippets_list: List[Dict[str, Any]], paper_metadata: Dict[str, Any]) -> List[
        Dict[str, Any]]:
        """Group passage-level snippets by corpus_id, attach paper metadata, and
        assign each paper a relevance_judgement = max snippet score (preferring
        rerank_score when present). Returns papers sorted by that score, desc."""
        # Fixed: use the module-level `logger` (as elsewhere in this module)
        # instead of the root `logging` module.
        logger.info("Aggregating the passages at paper level with metadata")
        paper_snippets = dict()
        for snippet in snippets_list:
            corpus_id = snippet["corpus_id"]
            if corpus_id not in paper_snippets:
                paper_snippets[corpus_id] = paper_metadata[corpus_id]
                paper_snippets[corpus_id]["corpus_id"] = corpus_id
                paper_snippets[corpus_id]["sentences"] = []
            paper_snippets[corpus_id]["sentences"].append(snippet)
            # Paper score is the best score among its snippets.
            paper_snippets[corpus_id]["relevance_judgement"] = max(
                paper_snippets[corpus_id].get("relevance_judgement", -1),
                snippet.get("rerank_score", snippet["score"]))
            # Backfill a missing abstract from a snippet drawn from the abstract section.
            if not paper_snippets[corpus_id]["abstract"] and snippet["section_title"] == "abstract":
                paper_snippets[corpus_id]["abstract"] = snippet["text"]
        sorted_ctxs = sorted(paper_snippets.values(), key=lambda x: x["relevance_judgement"], reverse=True)
        logger.info(f"Scores after aggregation: {[s['relevance_judgement'] for s in sorted_ctxs]}")
        return sorted_ctxs

    def format_retrieval_response(self, agg_reranked_candidates: List[Dict[str, Any]]) -> pd.DataFrame:
        """Convert aggregated paper dicts into a DataFrame, filter by the
        context threshold, and build markdown/reference strings per paper.
        Returns an empty DataFrame when nothing survives filtering."""

        def format_sections_to_markdown(row: List[Dict[str, Any]]) -> str:
            # convenience function to format the sections of a paper into markdown for function below
            # Convert the list of dictionaries to a DataFrame
            sentences_df = pd.DataFrame(row)
            if sentences_df.empty:
                return ""
            # Sort by 'char_offset' to ensure sentences are in the correct order
            sentences_df.sort_values(by="char_start_offset", inplace=True)
            # Group by 'section_title', concatenate sentences, and maintain overall order by the first 'char_offset'
            grouped = sentences_df.groupby("section_title", sort=False)["text"].apply("\n...\n".join)
            # Exclude sections titled 'Abstract' or 'Title'
            grouped = grouped[(grouped.index != "Abstract") & (grouped.index != "Title")]
            # Format as Markdown
            markdown_output = "\n\n".join(f"## {title}\n{text}" for title, text in grouped.items())
            return markdown_output

        df = pd.DataFrame(agg_reranked_candidates)
        # Require both snippets and a publication year for a row to be usable.
        df = df[~df.sentences.isna() & ~df.year.isna()] if not df.empty else df
        if df.empty:
            return df
        df["corpus_id"] = df["corpus_id"].astype(int)
        # there are multiple relevance judgments in ['relevance_judgements'] for each paper
        # we will keep rows where ANY of the relevance judgments are 2 or 3
        df = df[df["relevance_judgement"] > self.context_threshold]
        if df.empty:
            return df
        # authors are lists of jsons. process with "name" key inside
        df["year"] = df["year"].apply(make_int)
        df["authors"] = df["authors"].fillna(value="")
        df.rename(
            columns={
                "citationCount": "citation_count",
                "referenceCount": "reference_count",
                "influentialCitationCount": "influential_citation_count",
            },
            inplace=True,
        )
        # drop corpusId, paperId,
        df = df.drop(columns=["corpusId", "paperId"])
        # now we need the big relevance_judgment_input_expanded
        # top of it
        # \n## Abstract\n{row['abstract']} --> Not using abstracts OR could use and not show
        prepend_text = df.apply(
            lambda
                row: f"# Title: {row['title']}\n# Venue: {row['venue']}\n"
                     f"# Authors: {', '.join([a['name'] for a in row['authors']])}\n## Abstract\n{row['abstract']}\n",
            axis=1,
        )
        section_text = df["sentences"].apply(format_sections_to_markdown)
        # update relevance_judgment_input
        df.loc[:, "relevance_judgment_input_expanded"] = prepend_text + section_text
        df["reference_string"] = df.apply(
            lambda
                row: f"[{make_int(row.corpus_id)} | {get_ref_author_str(row.authors)} | "
                     f"{make_int(row['year'])} | Citations: {make_int(row['citation_count'])}]",
            axis=1,
        )
        return df
class PaperFinderWithReranker(PaperFinder):
    """Paper finder that scores retrieved passages with a reranker model and
    keeps the top ``n_rerank`` passages (all of them when ``n_rerank <= 0``)."""

    def __init__(self, retriever: AbstractRetriever, reranker: AbstractReranker, n_rerank: int = -1,
                 context_threshold: float = 0.5):
        # n_rerank: number of passages to keep after reranking; <= 0 keeps all.
        super().__init__(retriever, context_threshold)
        self.n_rerank = n_rerank
        if reranker:
            self.reranker_engine = reranker
        else:
            # ValueError is more specific than Exception and remains
            # backward-compatible for callers catching Exception.
            raise ValueError(f"Reranker not initialized: {reranker}")

    def rerank(
            self, query: str, retrieved_ctxs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank the retrieved passages using a cross-encoder model and return the top n passages."""
        # Prefix the title when present; the truthiness check also guards
        # against a present-but-None title, which would have raised a
        # TypeError on concatenation with the original `"title" in doc` test.
        passages = [doc["title"] + " " + doc["text"] if doc.get("title") else doc["text"]
                    for doc in retrieved_ctxs]
        rerank_scores = self.reranker_engine.get_scores(
            query, passages
        )
        logger.info(f"Reranker scores: {rerank_scores}")
        # Attach each score to its source passage, then sort best-first.
        for doc, rerank_score in zip(retrieved_ctxs, rerank_scores):
            doc["rerank_score"] = rerank_score
        sorted_ctxs = sorted(
            retrieved_ctxs, key=lambda x: x["rerank_score"], reverse=True
        )
        sorted_ctxs = super().rerank(query, sorted_ctxs)
        sorted_ctxs = sorted_ctxs[:self.n_rerank] if self.n_rerank > 0 else sorted_ctxs
        # Fixed: use the module-level `logger` (consistent with the rest of
        # this module) rather than the root `logging` module.
        logger.info(f"Done reranking: {len(sorted_ctxs)} passages remain")
        return sorted_ctxs