Skip to content

Commit

Permalink
Add client default timeout limitation and support UI interactive (#90)
Browse files Browse the repository at this point in the history
* Add client default timeout limitation and support UI interactive

* Support the interactive setting for vectordb type
  • Loading branch information
wwxxzz authored Jul 3, 2024
1 parent e0ed40f commit 93a4e67
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import mimetypes
from pai_rag.app.web.view_model import ViewModel

DEFAULT_CLIENT_TIME_OUT = 60


class dotdict(dict):
"""dot.notation access to dictionary attributes"""
Expand Down Expand Up @@ -62,7 +64,7 @@ def get_evaluate_response_url(self):

def query(self, text: str, session_id: str = None):
q = dict(question=text, session_id=session_id)
r = requests.post(self.query_url, json=q)
r = requests.post(self.query_url, json=q, timeout=DEFAULT_CLIENT_TIME_OUT)
r.raise_for_status()
response = dotdict(json.loads(r.text))
referenced_docs = ""
Expand All @@ -88,15 +90,15 @@ def query_llm(
session_id=session_id,
)

r = requests.post(self.llm_url, json=q)
r = requests.post(self.llm_url, json=q, timeout=DEFAULT_CLIENT_TIME_OUT)
r.raise_for_status()
response = dotdict(json.loads(r.text))

return response

def query_vector(self, text: str):
q = dict(question=text)
r = requests.post(self.retrieval_url, json=q)
r = requests.post(self.retrieval_url, json=q, timeout=DEFAULT_CLIENT_TIME_OUT)
r.raise_for_status()
response = dotdict(json.loads(r.text))
formatted_text = "<tr><th>Document</th><th>Score</th><th>Text</th></tr>\n"
Expand All @@ -122,8 +124,7 @@ def add_knowledge(self, input_files: str, enable_qa_extraction: bool):

try:
r = requests.post(
self.load_data_url,
files=files,
self.load_data_url, files=files, timeout=DEFAULT_CLIENT_TIME_OUT
)
r.raise_for_status()
except:
Expand All @@ -136,7 +137,7 @@ def add_knowledge(self, input_files: str, enable_qa_extraction: bool):
return response

async def get_knowledge_state(self, task_id: str):
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=DEFAULT_CLIENT_TIME_OUT) as client:
r = await client.get(self.get_load_state_url, params={"task_id": task_id})
r.raise_for_status()
response = dotdict(json.loads(r.text))
Expand All @@ -148,33 +149,41 @@ def patch_config(self, update_dict: Any):
view_model.update(update_dict)
new_config = view_model.to_app_config()

r = requests.patch(self.config_url, json=new_config)
r = requests.patch(
self.config_url, json=new_config, timeout=DEFAULT_CLIENT_TIME_OUT
)
r.raise_for_status()
return

def get_config(self):
r = requests.get(self.config_url)
r = requests.get(self.config_url, timeout=DEFAULT_CLIENT_TIME_OUT)
r.raise_for_status()
response = dotdict(json.loads(r.text))
print(response)
return response

def evaluate_for_generate_qa(self, overwrite):
r = requests.post(
self.get_evaluate_generate_url, params={"overwrite": overwrite}
self.get_evaluate_generate_url,
params={"overwrite": overwrite},
timeout=DEFAULT_CLIENT_TIME_OUT,
)
r.raise_for_status()
response = dotdict(json.loads(r.text))
return response

def evaluate_for_retrieval_stage(self):
r = requests.post(self.get_evaluate_retrieval_url)
r = requests.post(
self.get_evaluate_retrieval_url, timeout=DEFAULT_CLIENT_TIME_OUT
)
r.raise_for_status()
response = dotdict(json.loads(r.text))
return response

def evaluate_for_response_stage(self):
r = requests.post(self.get_evaluate_response_url)
r = requests.post(
self.get_evaluate_response_url, timeout=DEFAULT_CLIENT_TIME_OUT
)
r.raise_for_status()
response = dotdict(json.loads(r.text))
print("evaluate_for_response_stage response", response)
Expand Down
5 changes: 5 additions & 0 deletions src/pai_rag/app/web/tabs/settings_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from pai_rag.app.web.utils import components_to_dict
from pai_rag.app.web.tabs.vector_db_panel import create_vector_db_panel
import logging
import os

logger = logging.getLogger(__name__)

DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")


def connect_vector_db(input_elements: List[Any]):
try:
Expand All @@ -39,6 +42,7 @@ def create_setting_tab() -> Dict[str, Any]:
EMBEDDING_API_KEY_DICT.keys(),
label="Embedding Type",
elem_id="embed_source",
interactive=DEFAULT_IS_INTERACTIVE.lower() != "false",
)
embed_model = gr.Dropdown(
EMBEDDING_DIM_DICT.keys(),
Expand Down Expand Up @@ -88,6 +92,7 @@ def change_emb_model(model):
["PaiEas", "OpenAI", "DashScope"],
label="LLM Model Source",
elem_id="llm",
interactive=DEFAULT_IS_INTERACTIVE.lower() != "false",
)
with gr.Column(visible=(llm == "PaiEas")) as eas_col:
llm_eas_url = gr.Textbox(
Expand Down
4 changes: 4 additions & 0 deletions src/pai_rag/app/web/tabs/vector_db_panel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import gradio as gr
from typing import Any, Set, Callable, Dict
from pai_rag.app.web.utils import components_to_dict
import os

DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")


def create_vector_db_panel(
Expand All @@ -15,6 +18,7 @@ def create_vector_db_panel(
["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"],
label="Which VectorStore do you want to use?",
elem_id="vectordb_type",
interactive=DEFAULT_IS_INTERACTIVE.lower() != "false",
)
# Adb
with gr.Column(visible=(vectordb_type == "AnalyticDB")) as adb_col:
Expand Down

0 comments on commit 93a4e67

Please sign in to comment.