Skip to content

Commit

Permalink
Add transformed query to query_str (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwxxzz authored Feb 14, 2025
1 parent 4b72f67 commit 2170351
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ async def achat(

new_question = new_query_bundle.query_str
logger.info(f"Querying with question '{new_question}'.")
messages[-1].content = ",".join([question, new_question])
query_bundle = PaiQueryBundle(
query_str=new_question,
stream=chat_request.stream,
Expand Down
36 changes: 28 additions & 8 deletions src/pai_rag/integrations/query_transform/pai_query_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def _run(self, query_bundle: QueryBundle, session_id, chat_history) -> QueryBund
self._chat_store.add_message(hist_mes)

chat_history = self._chat_store.get_messages(session_id)
chat_history.append(ChatMessage(role="user", content=query_str))
chat_history_str = messages_to_history_str(chat_history)

logger.debug(f"Chat history: {chat_history_str}")
Expand All @@ -205,22 +204,31 @@ def _run(self, query_bundle: QueryBundle, session_id, chat_history) -> QueryBund
question=query_str,
chat_history=chat_history_str,
)
logger.debug(f"Transformed query: {transformed_query_str}")
logger.debug(f"Transformed query [{query_str}] --> [{transformed_query_str}]")
# 修复thought输出
transformed_query_str = re.sub(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
)
query_json = parse_json_from_code_block_str(transformed_query_str)
if ("queries" not in query_json) or (len(query_json["queries"]) == 0):
chat_history.append(ChatMessage(role="user", content=query_str))
chat_history_str = messages_to_history_str(chat_history)
return PaiQueryBundle(
query_str=query_str,
need_web_search=False,
custom_embedding_strs=[query_str, transformed_query_str],
chat_messages_str=chat_history_str,
)
else:
transformed_queries = ",".join(query_json["queries"])
chat_history.append(
ChatMessage(
role="user", content=",".join([query_str, transformed_queries])
)
)
chat_history_str = messages_to_history_str(chat_history)
return PaiQueryBundle(
query_str=",".join(query_json["queries"]),
query_str=transformed_queries,
need_web_search=True,
custom_embedding_strs=[query_str, transformed_query_str],
chat_messages_str=chat_history_str,
Expand Down Expand Up @@ -255,7 +263,6 @@ async def _arun(
self._chat_store.add_message(key=session_id, message=hist_mes)

chat_history = self._chat_store.get_messages(key=session_id)
chat_history.append(ChatMessage(role="user", content=query_str))
chat_history_str = messages_to_history_str(chat_history)

logger.debug(f"Chat history: {chat_history_str}")
Expand All @@ -264,22 +271,31 @@ async def _arun(
question=query_str,
chat_history=chat_history_str,
)
logger.debug(f"Transformed query: {transformed_query_str}")
logger.debug(f"Transformed query [{query_str}] --> [{transformed_query_str}]")
# 修复thought输出
transformed_query_str = re.sub(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
)
query_json = parse_json_from_code_block_str(transformed_query_str)
if ("queries" not in query_json) or (len(query_json["queries"]) == 0):
chat_history.append(ChatMessage(role="user", content=query_str))
chat_history_str = messages_to_history_str(chat_history)
return PaiQueryBundle(
query_str=query_str,
need_web_search=False,
custom_embedding_strs=[query_str, transformed_query_str],
chat_messages_str=chat_history_str,
)
else:
transformed_queries = ",".join(query_json["queries"])
chat_history.append(
ChatMessage(
role="user", content=",".join([query_str, transformed_queries])
)
)
chat_history_str = messages_to_history_str(chat_history)
return PaiQueryBundle(
query_str=",".join(query_json["queries"]),
query_str=transformed_queries,
need_web_search=True,
custom_embedding_strs=[query_str, transformed_query_str],
chat_messages_str=chat_history_str,
Expand Down Expand Up @@ -341,7 +357,9 @@ def run(
question=chat_messages[-1].content,
chat_history=chat_history_str,
)
logger.debug(f"Transformed query: {transformed_query_str}")
logger.debug(
f"Transformed query [{chat_messages[-1].content}] --> [{transformed_query_str}]"
)
# 修复thought输出
transformed_query_str = re.sub(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
Expand Down Expand Up @@ -381,7 +399,9 @@ async def arun(
question=chat_messages[-1].content,
chat_history=chat_history_str,
)
logger.debug(f"Transformed query: {transformed_query_str}")
logger.debug(
f"Transformed query [{chat_messages[-1].content}] --> [{transformed_query_str}]"
)
# 修复thought输出
transformed_query_str = re.sub(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/search/aliyun_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def aquery(
prompt_template_str=prompt_template_str,
)

logger.info(f"Search with query {query.query_str,}.")
logger.info(f"Aliyun Search with query {query.query_str,}.")
nodes = await self._asearch(query=query.query_str)
logger.info(f"Get {len(nodes)} docs from url.")

Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/search/bing_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def aquery(
prompt_template_str=prompt_template_str,
)

logger.info(f"Search with query {query.query_str,}.")
logger.info(f"Bing Search with query {query.query_str,}.")
docs = await self._asearch(
query=query.query_str,
)
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/search/quark_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def aquery(
prompt_template_str=prompt_template_str,
)

logger.info(f"Search with query {query.query_str,}.")
logger.info(f"Quark Search with query {query.query_str,}.")
nodes = await self.asearch(query=query.query_str)
logger.info(f"Get {len(nodes)} docs from url.")

Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/utils/json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def parse_json_from_code_block_str(input_str):
content = input_str[start : end + 1]
try:
data = json.loads(content)
logger.debug("解析后的 JSON 对象:", data)
logger.debug(f"解析后的 JSON 对象:{data}")
return data
except json.JSONDecodeError as e:
logger.debug("JSON 解码错误:", e)
Expand Down

0 comments on commit 2170351

Please sign in to comment.