Skip to content

Commit

Permalink
Add fast bm25 (#66)
Browse files (browse the repository at this point in the history)
* Add fast bm25

* Fix bm25 bug

* Fix bug

* Fix test
  • Loading branch information
moria97 committed Jun 20, 2024
1 parent d55b7d9 commit e0923ad
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/pai_rag/data/rag_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,3 @@ async def aload(self, file_directory: str, enable_qa_extraction: bool):
def load(self, file_directory: str, enable_qa_extraction: bool):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aload(file_directory, enable_qa_extraction))
return
1 change: 0 additions & 1 deletion src/pai_rag/modules/index/pai_bm25_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def query(self, query_str: str, top_n: int = 5) -> List[NodeWithScore]:

doc_scores = self.index_matrix.multiply(query_vec).sum(axis=1).getA1()
doc_indexes = doc_scores.argsort()[::-1][:top_n]

text_nodes = self.load_docs_with_index(doc_indexes)
for i, node in enumerate(text_nodes):
results.append(NodeWithScore(node=node, score=doc_scores[doc_indexes[i]]))
Expand Down
7 changes: 2 additions & 5 deletions tests/core/test_rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,11 @@ def rag_app():
rag_app = RagApplication()
rag_app.initialize(config)

return rag_app


# Test load knowledge file
def test_add_knowledge_file(rag_app: RagApplication):
data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
rag_app.load_knowledge(data_dir)

return rag_app


# Test rag query
async def test_query(rag_app: RagApplication):
Expand Down

0 comments on commit e0923ad

Please sign in to comment.