import asyncio
import json
import os
from typing import AsyncIterable, List, Optional

from fastapi import Body, File, Form, UploadFile
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from chatchat.settings import Settings
from chatchat.server.chat.utils import History
from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import (
    BaseResponse,
    get_ChatOpenAI,
    get_Embeddings,
    get_prompt_template,
    get_temp_dir,
    run_in_thread_pool,
    wrap_done,
)
from chatchat.utils import build_logger

logger = build_logger()
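
# A minimal client sketch for the two handlers below. The mount paths
# "/knowledge_base/upload_temp_docs" and "/chat/file_chat", the port, and the
# file name are assumptions (they depend on how the app registers these
# routes), not something this file defines:
#
#     import requests
#
#     resp = requests.post(
#         "http://127.0.0.1:7861/knowledge_base/upload_temp_docs",
#         files=[("files", open("report.pdf", "rb"))],
#     )
#     kb_id = resp.json()["data"]["id"]
#     with requests.post(
#         "http://127.0.0.1:7861/chat/file_chat",
#         json={"query": "Summarize this report", "knowledge_id": kb_id},
#         stream=True,
#     ) as r:
#         for line in r.iter_lines():
#             print(line)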


def _parse_files_in_thread(
    files: List[UploadFile],
    dir: str,
    zh_title_enhance: bool,
    chunk_size: int,
    chunk_overlap: int,
):
    """
    Save the uploaded files to the target directory using a thread pool.
    Yields one result per file: (success, filename, msg, docs).
    """

    def parse_file(file: UploadFile) -> tuple:
        """
        Save a single uploaded file and split it into documents.
        """
        filename = file.filename
        try:
            file_path = os.path.join(dir, filename)
            file_content = file.file.read()  # read the uploaded file's bytes
            if not os.path.isdir(os.path.dirname(file_path)):
                os.makedirs(os.path.dirname(file_path))
            with open(file_path, "wb") as f:
                f.write(file_content)
            kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
            kb_file.filepath = file_path
            docs = kb_file.file2text(
                zh_title_enhance=zh_title_enhance,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            return True, filename, f"Successfully uploaded file {filename}", docs
        except Exception as e:
            msg = f"Failed to upload {filename}, error: {e}"
            return False, filename, msg, []
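
    # Fan the files out to a worker pool; run_in_thread_pool yields each
    # file's result as soon as that file finishes parsing.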
    params = [{"file": file} for file in files]
    for result in run_in_thread_pool(parse_file, params=params):
        yield result


def upload_temp_docs(
    files: List[UploadFile] = File(..., description="Uploaded files; multiple files are supported"),
    prev_id: Optional[str] = Form(None, description="ID of the previous temporary knowledge base"),
    chunk_size: int = Form(Settings.kb_settings.CHUNK_SIZE, description="Maximum length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Form(Settings.kb_settings.OVERLAP_SIZE, description="Overlap length between adjacent text chunks"),
    zh_title_enhance: bool = Form(Settings.kb_settings.ZH_TITLE_ENHANCE, description="Whether to enable Chinese title enhancement"),
) -> BaseResponse:
    """
    Save the files to a temporary directory and vectorize them.
    Returns the temporary directory name as the ID, which is also the ID of
    the temporary vector store.
    """
    if prev_id is not None:
        memo_faiss_pool.pop(prev_id)

    failed_files = []
    documents = []
    path, id = get_temp_dir(prev_id)
    for success, file, msg, docs in _parse_files_in_thread(
        files=files,
        dir=path,
        zh_title_enhance=zh_title_enhance,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ):
        if success:
            documents += docs
        else:
            failed_files.append({file: msg})
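
    # Vectorize every successfully parsed chunk into the in-memory FAISS store
    # keyed by the temp directory's id; failures are only logged, so the caller
    # still receives the id together with the per-file error list.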
    try:
        with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs:
            vs.add_documents(documents)
    except Exception as e:
        logger.error(f"Failed to add documents to faiss: {e}")
    return BaseResponse(data={"id": id, "failed_files": failed_files})


async def file_chat(
    query: str = Body(..., description="User input", examples=["你好"]),
    knowledge_id: str = Body(..., description="ID of the temporary knowledge base"),
    top_k: int = Body(Settings.kb_settings.VECTOR_SEARCH_TOP_K, description="Number of vectors to match"),
    score_threshold: float = Body(
        Settings.kb_settings.SCORE_THRESHOLD,
        description=(
            "Relevance threshold for knowledge base matching, typically between 0 and 1. "
            "A smaller SCORE means higher relevance, and 1 is equivalent to no filtering. "
            "A value around 0.5 is recommended."
        ),
        ge=0,
        le=2,
    ),
    history: List[History] = Body(
        [],
        description="Chat history",
        examples=[
            [
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"},
            ]
        ],
    ),
    stream: bool = Body(False, description="Stream the output"),
    model_name: Optional[str] = Body(None, description="LLM model name."),
    temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None, description="Maximum number of tokens for the LLM to generate; None means the model's maximum"
    ),
    prompt_name: str = Body(
        "default",
        description="Name of the prompt template to use (configured in prompt_settings.yaml)",
    ),
):
    if knowledge_id not in memo_faiss_pool.keys():
        # return BaseResponse(code=404, msg=f"Temporary knowledge base {knowledge_id} not found; please upload files first")
        return BaseResponse(
            code=404,
            msg=(
                "[Let's go!] Welcome to the 环评查特助手 (EIA review assistant)!\r\n"
                "Please upload files such as an EIA report first to enable review "
                "supervision checks, project risk analysis, and other features."
            ),
        )
    history = [History.from_data(h) for h in history]
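
    # The whole RAG pipeline runs inside this async generator so that the
    # streaming and non-streaming paths can both be served from the single
    # EventSourceResponse returned at the bottom.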
    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        try:
            nonlocal max_tokens
            callback = AsyncIteratorCallbackHandler()
            if isinstance(max_tokens, int) and max_tokens <= 0:
                max_tokens = None
            callbacks = [callback]

            # Enable langchain-chatchat to support langfuse
            langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
            langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
            langfuse_host = os.environ.get("LANGFUSE_HOST")
            if langfuse_secret_key and langfuse_public_key and langfuse_host:
                # CallbackHandler reads the LANGFUSE_* settings from the environment
                from langfuse.callback import CallbackHandler

                langfuse_handler = CallbackHandler()
                callbacks.append(langfuse_handler)
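
            # Build the chat model with the combined callbacks; the
            # AsyncIteratorCallbackHandler is what turns the model's tokens
            # into the async stream consumed below.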
            model = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=callbacks,
            )
            embed_func = get_Embeddings()
            embeddings = await embed_func.aembed_query(query)
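
            # Query the temporary FAISS store with the embedded question;
            # similarity_search_with_score_by_vector returns (Document, score)
            # pairs, so keep only the documents.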
            with memo_faiss_pool.acquire(knowledge_id) as vs:
                docs = vs.similarity_search_with_score_by_vector(
                    embeddings, k=top_k, score_threshold=score_threshold
                )
                docs = [x[0] for x in docs]
            context = "\n".join([doc.page_content for doc in docs])

            if len(docs) == 0:
                # No relevant documents found: fall back to the "empty" template
                prompt_template = get_prompt_template("rag", "empty")
            else:
                prompt_template = get_prompt_template("rag", prompt_name)
            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_template() for i in history] + [input_msg]
            )
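
            # The chain renders {context} and {question} into the prompt and
            # streams tokens back through the AsyncIteratorCallbackHandler.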
            chain = LLMChain(prompt=chat_prompt, llm=model)

            # Begin a task that runs in the background.
            task = asyncio.create_task(
                wrap_done(
                    chain.acall({"context": context, "question": query}),
                    callback.done,
                ),
            )

            source_documents = []
            for inum, doc in enumerate(docs):
                filename = doc.metadata.get("source")
                text = f"""Source [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
                source_documents.append(text)
            if len(source_documents) == 0:  # no relevant documents were found
                source_documents.append(
                    "<span style='color:red'>No relevant documents were found; "
                    "this answer is based on the model's own knowledge!</span>"
                )

            if stream:
                # Use server-sent events to stream the answer token by token,
                # then send the source documents as a final message.
                async for token in callback.aiter():
                    yield json.dumps({"answer": token}, ensure_ascii=False)
                yield json.dumps({"docs": source_documents}, ensure_ascii=False)
            else:
                answer = ""
                async for token in callback.aiter():
                    answer += token
                yield json.dumps(
                    {"answer": answer, "docs": source_documents}, ensure_ascii=False
                )
            await task
        except asyncio.exceptions.CancelledError:
            logger.warning("Streaming was interrupted by the user.")
            return

    return EventSourceResponse(knowledge_base_chat_iterator())