Skip to content

Commit

Permalink
refine query rewrite (#412)
Browse files Browse the repository at this point in the history
* refine query rewrite

* fix chat query

* update prompt

---------

Co-authored-by: ranxia <chenanyu.cay@alibaba-inc.com>
  • Loading branch information
moria97 and Ceceliachenen authored Feb 26, 2025
1 parent dbc59f2 commit a96a685
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 31 deletions.
26 changes: 15 additions & 11 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ async def aretrieve(
new_question = new_query_bundle.query_str
logger.info(f"Transformed question '{new_question}'.")
if new_question != question:
new_question = ",".join([question, new_question])
new_question = " ".join([question, new_question])
logger.info(f"Querying with question '{new_question}'.")

query_bundle = QueryBundle(new_question)
Expand Down Expand Up @@ -553,9 +553,6 @@ async def achat(
# Condense question
new_question = new_query_bundle.query_str
logger.info(f"Transformed question '{new_question}'.")
if new_question != question:
new_question = ",".join([question, new_question])
logger.info(f"Querying with question '{new_question}'.")

if not passed_guardrail:
# 多轮对话,用新查询检查
Expand All @@ -571,10 +568,6 @@ async def achat(
)
passed_guardrail = True

logger.info(f"Querying with question '{new_question}'.")
if new_question != question:
messages[-1].content = ",".join([question, new_question])

query_bundle = PaiQueryBundle(
query_str=new_question,
stream=chat_request.stream,
Expand All @@ -593,6 +586,8 @@ async def achat(
chat_request.search_web = False

if chat_request.search_web:
logger.info(f"Querying with question '{query_bundle.query_str}'.")

search_engine = resolve_searcher(self.config)
if not search_engine:
raise ValueError(
Expand All @@ -617,6 +612,11 @@ async def achat(
return_reference=chat_request.return_reference,
)

if new_question != question:
query_bundle.query_str = " ".join([question, new_question])

logger.info(f"Querying with question '{query_bundle.query_str}'.")

session_config = self.config.model_copy()
index_entry = index_manager.get_index_by_name(chat_request.index_name)
session_config.embedding = index_entry.embedding_config
Expand Down Expand Up @@ -712,9 +712,6 @@ async def aquery(
# Condense question
new_question = new_query_bundle.query_str
logger.info(f"Transformed question '{new_question}'.")
if new_question != question:
new_question = ",".join([question, new_question])
logger.info(f"Querying with question '{new_question}'.")

guardrail = resolve_llm_guardrail(self.config)
# 多轮对话,用新查询检查
Expand Down Expand Up @@ -759,6 +756,11 @@ async def aquery(
chat_messages_str=new_query_bundle.chat_messages_str,
)
if chat_type == RagChatType.RAG:
if new_question != question:
query_bundle.query_str = " ".join([question, new_question])

logger.info(f"Querying with question '{query_bundle.query_str}'.")

session_config = self.config.model_copy()
index_entry = index_manager.get_index_by_name(query.index_name)
session_config.embedding = index_entry.embedding_config
Expand All @@ -771,6 +773,8 @@ async def aquery(
prompt_template_str=query.custom_prompt_template,
)
elif chat_type == RagChatType.WEB:
logger.info(f"Querying with question '{new_question}'.")

search_engine = resolve_searcher(self.config)
if not search_engine:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/llms/pai/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from llama_index.core.constants import DEFAULT_TEMPERATURE

DEFAULT_MAX_TOKENS = 2000
DEFAULT_MAX_TOKENS = 4000


class DashScopeGenerationModels:
Expand Down
11 changes: 7 additions & 4 deletions src/pai_rag/integrations/query_transform/pai_query_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def run(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
)
query_json = parse_json_from_code_block_str(transformed_query_str)
if ("queries" not in query_json) or (len(query_json["queries"]) == 0):
if ("query" not in query_json) or (len(query_json["query"]) == 0):
return PaiQueryBundle(
query_str=chat_messages[-1].content,
need_web_search=False,
Expand All @@ -235,7 +235,7 @@ def run(
)
else:
return PaiQueryBundle(
query_str=",".join(query_json["queries"]),
query_str=query_json["query"],
need_web_search=True,
custom_embedding_strs=[
chat_messages[-1].content,
Expand Down Expand Up @@ -276,16 +276,19 @@ async def arun(
r"<think>.*?</think>\n*", "", transformed_query_str, flags=re.DOTALL
)
query_json = parse_json_from_code_block_str(transformed_query_str)
if ("queries" not in query_json) or (len(query_json["queries"]) == 0):

if ("query" not in query_json) or (len(query_json["query"]) == 0):
return PaiQueryBundle(
query_str=chat_messages[-1].content,
need_web_search=False,
custom_embedding_strs=[chat_messages[-1].content],
chat_messages_str=chat_history_str,
)
else:
if chat_messages[-1].content != query_json["query"]:
chat_history_str += f' {query_json["query"]}'
return PaiQueryBundle(
query_str=",".join(query_json["queries"]),
query_str=query_json["query"],
need_web_search=True,
custom_embedding_strs=[transformed_query_str],
chat_messages_str=chat_history_str,
Expand Down
8 changes: 4 additions & 4 deletions src/pai_rag/integrations/synthesizer/prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
**任务要求:**
- 请严格根据提供的参考内容回答问题,仅参考与问题相关的内容并忽略所有不相关的信息。
- 如果参考内容中没有相关信息或与问题无关,请基于你的已有知识进行回答。
- 确保答案准确、简洁,并且使用与提问相同的语言
- 确保答案准确、简洁,并且使用与用户提问相同的语种
- 在回答过程中,请避免使用“从参考内容得出”、“从材料得出”、“根据参考内容”等措辞。
- 保持回答的专业性和友好性。
- 如果需要更多信息来更好地回答问题,请礼貌地询问。
- 对于复杂的问题,尽量简化解释,使信息易于理解。
- 保持输出语言与输入语言的一致性
- 请保持输出语种与用户输入问题语种的一致性
- 对于涉及不安全/不道德/敏感/色情/暴力/赌博/违法等行为的问题,请明确拒绝提供所要求的信息,并简单解释为什么这样的请求不能被满足。
"""

Expand Down Expand Up @@ -87,7 +87,7 @@
"{context_str}\n"
"------\n"
"问题: {query_str}\n"
"请仔细思考,并使用与提问相同的语言来提供你的答案\n"
"请仔细思考,并使用与提问相同的语种来提供你的答案\n"
)


Expand Down Expand Up @@ -146,7 +146,7 @@
"{context_str}\n"
"------\n"
"问题: {query_str}\n"
"请必须使用和提问相同的语言,仔细思考,给出你的答案:"
"请必须使用和提问相同的语种,仔细思考,给出你的答案:"
)

CITATION_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL_EN = (
Expand Down
16 changes: 9 additions & 7 deletions src/pai_rag/utils/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,27 @@
## 技能
### 技能 1: 聊天记录分析
- 分析提供的聊天记录,判断是否需要生成搜索查询。
- 如果存在任何不确定性或可能获取到有用信息的情况,只需生成1-2个广泛且相关的搜索查询
- 如果存在任何不确定性或可能获取到有用信息的情况,只需生成1个相关且精确的搜索查询
### 技能 2: 生成搜索查询
- 生成的搜索查询应简洁、明确且与主题相关。
- 查询应尽可能广泛,以便获取更多相关信息。
- 时间相关查询:(1)高频波动信息(如黄金价格、外汇汇率、股票价格等):请提供具体且最新的时间信息,例如最新一天或实时数据,并使用适当的短时间间隔。(2)低频更新信息(如汽车评测、电影上映、歌曲发布等):请使用较宽泛的时间范围,如最近一个月或更长时间,并提供相关的时间信息。
- 查询应尽可能精准,以便获取更多相关信息。
- 时间相关查询
- 高频波动信息(如黄金价格、外汇汇率、股票价格等):请提供具体且最新的时间信息,例如最新一天或实时数据,并使用适当的短时间间隔。如今天为2025年1月1日,搜索"xxxx最新股价"改写为`2025年1月1日xxxx股价`.
- 低频更新信息(如汽车评测、电影上映、歌曲发布等):请使用较宽泛的时间范围,如最近一个月或更长时间,并提供相关的时间信息。如今天为2025年1月1日,搜索`最近好看的电影`改写为`2025年1月好看的电影`。
- 非时间相关查询:避免随意添加时间信息,确保回答专注于查询的主要内容。
- 生成的查询格式为 JSON 对象:```{ "queries": ["query1", "query2"] }```。
- 生成的查询格式为 JSON 对象:```{ "query": "new query" }```。
### 技能 3: 确定无需搜索
- 如果完全确定不需要额外信息,返回空列表:```{ "queries": [] }```。
- 如果完全确定不需要额外信息,返回空字符串:```{ "query": "" }```。
## 限制
- **仅**以 JSON 对象的形式响应,不允许任何形式的额外评论、解释或附加文本。
- 除非绝对确定没有有用的结果可以通过搜索获得,否则建议生成搜索查询。
- 在生成搜索查询时,确保每个查询都是独立的、简洁的,并且与主题相关。
- 保持输出格式的一致性,严格遵循给定的 JSON 格式要求。
- 简明扼要地专注于撰写高质量的搜索查询,避免不必要的详细说明、评论或假设。
- 保持输出语言与输入语言的一致性
- 保持输出语种与用户输入问题语种的一致性
"""

CONDENSE_QUESTION_ANSWER_PROMPT_ZH = """## 聊天记录:
Expand All @@ -140,7 +142,7 @@
用户:
{question}
请仔细思考后,给出你的答案:
请仔细思考后,给出你的答案,请保持输出语种与用户输入问题语种的一致性
"""

QUERY_GEN_PROMPT = (
Expand Down
28 changes: 24 additions & 4 deletions tests/app/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

from pai_rag.app.app import app

DEFAULT_GUARDRAIL_RESPONSE = "抱歉,无法处理这个请求。"
DEFAULT_EMPTY_RESPONSE = "看起来你发了一条空白消息,有什么能帮到你的吗?"
DEFAULT_ERROR_RESPONSE = "抱歉,系统出错,暂时无法处理这个请求。"


async def upload_file(input_files, index_name="default_index"):
files = []
Expand Down Expand Up @@ -224,7 +228,11 @@ async def test_openai_websearch():

answer = response.json()["choices"][0]["message"]["content"]

assert "股价" in answer
assert (
answer != DEFAULT_GUARDRAIL_RESPONSE
and answer != DEFAULT_EMPTY_RESPONSE
and answer != DEFAULT_ERROR_RESPONSE
)
assert len(response.json()["citations"]) > 0

# 不返回reference
Expand All @@ -245,7 +253,11 @@ async def test_openai_websearch():

answer = response.json()["choices"][0]["message"]["content"]

assert "股价" in answer
assert (
answer != DEFAULT_GUARDRAIL_RESPONSE
and answer != DEFAULT_EMPTY_RESPONSE
and answer != DEFAULT_ERROR_RESPONSE
)
assert len(response.json()["citations"]) == 0


Expand Down Expand Up @@ -306,7 +318,11 @@ async def test_openai_websearch_stream():
answer += delta
citations = chunk_data.get("citations", [])

assert "股价" in answer
assert (
answer != DEFAULT_GUARDRAIL_RESPONSE
and answer != DEFAULT_EMPTY_RESPONSE
and answer != DEFAULT_ERROR_RESPONSE
)
assert len(citations) > 0

# 不返回引用
Expand Down Expand Up @@ -335,7 +351,11 @@ async def test_openai_websearch_stream():
answer += delta
citations = chunk_data.get("citations", [])

assert "股价" in answer
assert (
answer != DEFAULT_GUARDRAIL_RESPONSE
and answer != DEFAULT_EMPTY_RESPONSE
and answer != DEFAULT_ERROR_RESPONSE
)
assert len(citations) == 0


Expand Down

0 comments on commit a96a685

Please # to comment.