
Commit

Fix test bug (#174)
moria97 authored Aug 30, 2024
1 parent e2e11b2 commit 2289b9f
Showing 5 changed files with 31 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/pai_rag/app/api/query.py
@@ -146,7 +146,7 @@ async def upload_data(
             input_files.append(save_file)
 
     background_tasks.add_task(
-        rag_service.add_knowledge_async,
+        rag_service.add_knowledge,
         task_id=task_id,
         input_files=input_files,
         filter_pattern=None,
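
Note on the query.py change: the upload endpoint now hands the synchronous rag_service.add_knowledge to FastAPI's BackgroundTasks instead of the old coroutine. BackgroundTasks accepts both sync and async callables and runs sync ones in a threadpool after the response is sent, so the blocking ingestion no longer needs an async wrapper. A minimal, self-contained sketch of that pattern (the endpoint, task id, and file paths here are illustrative, not the project's real API):

import uuid
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def add_knowledge(task_id: str, input_files: list) -> None:
    # Stand-in for a blocking ingestion job (parsing, chunking, indexing ...).
    print(f"processing task {task_id}: {input_files}")

@app.post("/upload")
async def upload_data(background_tasks: BackgroundTasks):
    task_id = uuid.uuid4().hex
    input_files = ["localdata/example.pdf"]  # hypothetical saved upload paths
    # Sync callables passed to add_task are executed in a threadpool,
    # so the event loop is not blocked while files are ingested.
    background_tasks.add_task(add_knowledge, task_id=task_id, input_files=input_files)
    return {"task_id": task_id}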
4 changes: 2 additions & 2 deletions src/pai_rag/core/rag_application.py
@@ -51,7 +51,7 @@ def reload(self, config):
         self.initialize(config)
         self.logger.info("RagApplication reloaded successfully.")
 
-    async def aload_knowledge(
+    def load_knowledge(
         self,
         input_files,
         filter_pattern=None,
@@ -70,7 +70,7 @@ async def aload_knowledge(
         data_loader = module_registry.get_module_with_config(
             "DataLoaderModule", sessioned_config
         )
-        await data_loader.aload(
+        data_loader.load(
             input_files, filter_pattern, enable_qa_extraction, enable_raptor
         )
 
4 changes: 2 additions & 2 deletions src/pai_rag/core/rag_service.py
@@ -96,7 +96,7 @@ def check_updates(self):
         else:
             logger.debug("No configuration updates")
 
-    async def add_knowledge_async(
+    def add_knowledge(
         self,
         task_id: str,
         input_files: List[str],
@@ -109,7 +109,7 @@ async def add_knowledge_async(
         with open(TASK_STATUS_FILE, "a") as f:
             f.write(f"{task_id}\tprocessing\n")
         try:
-            await self.rag.aload_knowledge(
+            self.rag.load_knowledge(
                 input_files,
                 filter_pattern,
                 faiss_path,
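
Note on the rag_service.py change: the renamed add_knowledge is now a plain method, so it can be passed directly to BackgroundTasks, and it records task progress in TASK_STATUS_FILE before attempting the load. A hedged sketch of that bookkeeping pattern follows; only the "processing" entry appears in the diff, so the terminal states, the helper, and the file name are illustrative assumptions:

TASK_STATUS_FILE = "__upload_task_status.tmp"  # assumed name, not shown in the diff

def load_knowledge(input_files):
    # Placeholder for the blocking ingestion done by RagApplication.load_knowledge.
    pass

def add_knowledge(task_id: str, input_files: list) -> None:
    with open(TASK_STATUS_FILE, "a") as f:
        f.write(f"{task_id}\tprocessing\n")
    try:
        load_knowledge(input_files)
        status = "completed"          # assumed terminal state
    except Exception as ex:
        status = f"failed\t{ex}"      # assumed terminal state
    with open(TASK_STATUS_FILE, "a") as f:
        f.write(f"{task_id}\t{status}\n")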
14 changes: 8 additions & 6 deletions src/pai_rag/modules/llm/multi_modal_llm.py
@@ -1,10 +1,6 @@
 import logging
 from typing import Dict, List, Any
-from llama_index.multi_modal_llms.dashscope import (
-    DashScopeMultiModal,
-    DashScopeMultiModalModels,
-)
-
+import os
 from pai_rag.integrations.llms.multimodal.open_ai_alike_multi_modal import (
     OpenAIAlikeMultiModal,
 )
@@ -13,6 +9,8 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_DASHSCOPE_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+
 
 class MultiModalLlmModule(ConfigurableModule):
     @staticmethod
@@ -23,7 +21,11 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
         llm_config = new_params[MODULE_PARAM_CONFIG]
         if llm_config.source.lower() == "dashscope":
             logger.info("Using DashScope Multi-Modal-LLM.")
-            return DashScopeMultiModal(model_name=DashScopeMultiModalModels.QWEN_VL_MAX)
+            return OpenAIAlikeMultiModal(
+                model="qwen-vl-max",
+                api_base=DEFAULT_DASHSCOPE_API_BASE,
+                api_key=os.environ.get("DASHSCOPE_API_KEY"),
+            )
         elif llm_config.source.lower() == "paieas" and llm_config.get("endpoint"):
             logger.info("Using PAI-EAS Multi-Modal-LLM.")
             return OpenAIAlikeMultiModal(
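
Note on the multi_modal_llm.py change: the DashScope branch drops llama-index's DashScopeMultiModal class and instead reuses the project's OpenAI-compatible client, pointed at DashScope's compatible-mode endpoint with the model name and the DASHSCOPE_API_KEY environment variable. As a standalone illustration of what "OpenAI-alike" means here (this uses the plain openai client, not the project's OpenAIAlikeMultiModal class, and the image URL is made up):

import os
from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

# qwen-vl-max accepts OpenAI-style multimodal chat messages on this endpoint.
response = client.chat.completions.create(
    model="qwen-vl-max",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)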
35 changes: 18 additions & 17 deletions tests/core/test_rag_application.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from pathlib import Path
 from pai_rag.app.api.models import RagQuery
@@ -13,7 +14,7 @@
 
 
 @pytest.fixture
-async def rag_app():
+def rag_app():
     config_file = os.path.join(BASE_DIR, "src/pai_rag/config/settings.toml")
     config = RagConfiguration.from_file(config_file).get_value()
 
@@ -25,52 +26,52 @@ async def rag_app():
     rag_app.initialize(config)
 
     data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
-    await rag_app.aload_knowledge(data_dir)
+    rag_app.load_knowledge(data_dir)
 
     return rag_app
 
 
 # Test rag query
-async def test_query(rag_app: RagApplication):
+def test_query(rag_app: RagApplication):
     query = RagQuery(question="Why did he decide to learn AI?")
-    response = await rag_app.aquery(query)
+    response = asyncio.run(rag_app.aquery(query))
     assert len(response.answer) > 10
 
     query = RagQuery(question="")
-    response = await rag_app.aquery(query)
+    response = asyncio.run(rag_app.aquery(query))
     assert response.answer == EXPECTED_EMPTY_RESPONSE
 
 
 # Test llm query
-async def test_llm(rag_app: RagApplication):
+def test_llm(rag_app: RagApplication):
     query = RagQuery(question="What is the result of 15+22?")
-    response = await rag_app.aquery_llm(query)
+    response = asyncio.run(rag_app.aquery_llm(query))
     assert "37" in response.answer
 
     query = RagQuery(question="")
-    response = await rag_app.aquery_llm(query)
+    response = asyncio.run(rag_app.aquery_llm(query))
     assert response.answer == EXPECTED_EMPTY_RESPONSE
 
 
 # Test retrieval query
-async def test_retrieval(rag_app: RagApplication):
+def test_retrieval(rag_app: RagApplication):
     retrieval_query = RagQuery(question="Why did he decide to learn AI?")
-    response = await rag_app.aquery_retrieval(retrieval_query)
+    response = asyncio.run(rag_app.aquery_retrieval(retrieval_query))
     assert len(response.docs) > 0
 
-    query = RagQuery(question="")
-    response = await rag_app.aquery_retrieval(query)
+    empty_query = RagQuery(question="")
+    response = asyncio.run(rag_app.aquery_retrieval(empty_query))
     assert len(response.docs) == 0
 
 
 # Test agent query
-async def test_agent(rag_app: RagApplication):
+def test_agent(rag_app: RagApplication):
     query = RagQuery(question="What is the result of 15+22?")
-    response = await rag_app.aquery_agent(query)
+    response = asyncio.run(rag_app.aquery_agent(query))
     assert "37" in response.answer
 
     query = RagQuery(question="")
-    response = await rag_app.aquery_agent(query)
+    response = asyncio.run(rag_app.aquery_agent(query))
     assert response.answer == EXPECTED_EMPTY_RESPONSE
@@ -79,6 +80,6 @@ async def test_agent(rag_app: RagApplication):
 # print('eval_res_avg', eval_res_avg)
 
 
-async def test_load_evaluation_qa_dataset(rag_app: RagApplication):
-    qa_dataset = await rag_app.aload_evaluation_qa_dataset()
+def test_load_evaluation_qa_dataset(rag_app: RagApplication):
+    qa_dataset = asyncio.run(rag_app.aload_evaluation_qa_dataset())
     print(qa_dataset)
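
Note on the test changes: async test functions and fixtures are only awaited when an event-loop plugin such as pytest-asyncio is configured; the commit instead rewrites them as plain synchronous tests that drive the async RagApplication API through asyncio.run, which needs no plugin or event-loop fixture. A minimal self-contained sketch of the pattern (FakeApp is illustrative, not the real class):

import asyncio

class FakeApp:
    # Illustrative stand-in for RagApplication; only the call pattern matters.
    async def aquery(self, question: str) -> str:
        await asyncio.sleep(0)  # placeholder for real async retrieval/LLM calls
        return f"answer to: {question}"

def test_query_sync_wrapper():
    app = FakeApp()
    answer = asyncio.run(app.aquery("Why did he decide to learn AI?"))
    assert answer.startswith("answer")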
