Skip to content

Commit 35eb9d7

Browse files
committedNov 12, 2024
feat: 修改标注时可以选择默认知识库
--story=1016833 --user=王孝刚 对话日志-修改标注时可以选择默认知识库 #229 https://www.tapd.cn/57709429/s/1608846
1 parent 0f60017 commit 35eb9d7

File tree

7 files changed

+335
-9
lines changed

7 files changed

+335
-9
lines changed
 

‎apps/application/serializers/chat_serializers.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
4141
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
4242
from dataset.serializers.paragraph_serializers import ParagraphSerializers
43-
from embedding.task import embedding_by_paragraph
43+
from embedding.task import embedding_by_paragraph, embedding_by_paragraph_list
4444
from setting.models import Model
4545
from setting.models_provider import get_model_credential
4646
from smartdoc.conf import PROJECT_DIR
@@ -658,3 +658,67 @@ def delete(self, with_valid=True):
658658
data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
659659
o.is_valid(raise_exception=True)
660660
return o.delete()
661+
662+
class PostImprove(serializers.Serializer):
663+
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
664+
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
665+
chat_ids = serializers.ListSerializer(child=serializers.UUIDField(), required=True,
666+
error_messages=ErrMessage.list("对话id"))
667+
668+
def is_valid(self, *, raise_exception=False):
669+
super().is_valid(raise_exception=True)
670+
if not Document.objects.filter(id=self.data['document_id'], dataset_id=self.data['dataset_id']).exists():
671+
raise AppApiException(500, "文档id不正确")
672+
673+
@staticmethod
674+
def post_embedding_paragraph(paragraph_ids, dataset_id):
675+
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
676+
embedding_by_paragraph_list(paragraph_ids, model_id)
677+
678+
@post(post_function=post_embedding_paragraph)
679+
@transaction.atomic
680+
def post_improve(self, instance: Dict):
681+
ChatRecordSerializer.PostImprove(data=instance).is_valid(raise_exception=True)
682+
683+
chat_ids = instance['chat_ids']
684+
document_id = instance['document_id']
685+
dataset_id = instance['dataset_id']
686+
687+
# 获取所有聊天记录
688+
chat_record_list = list(ChatRecord.objects.filter(chat_id__in=chat_ids))
689+
if len(chat_record_list) < len(chat_ids):
690+
raise AppApiException(500, "存在不存在的对话记录")
691+
692+
# 批量创建段落和问题映射
693+
paragraphs = []
694+
paragraph_ids = []
695+
problem_paragraph_mappings = []
696+
for chat_record in chat_record_list:
697+
paragraph = Paragraph(
698+
id=uuid.uuid1(),
699+
document_id=document_id,
700+
content=chat_record.answer_text,
701+
dataset_id=dataset_id,
702+
title=chat_record.problem_text
703+
)
704+
problem, _ = Problem.objects.get_or_create(content=chat_record.problem_text, dataset_id=dataset_id)
705+
problem_paragraph_mapping = ProblemParagraphMapping(
706+
id=uuid.uuid1(),
707+
dataset_id=dataset_id,
708+
document_id=document_id,
709+
problem_id=problem.id,
710+
paragraph_id=paragraph.id
711+
)
712+
paragraphs.append(paragraph)
713+
paragraph_ids.append(paragraph.id)
714+
problem_paragraph_mappings.append(problem_paragraph_mapping)
715+
chat_record.improve_paragraph_id_list.append(paragraph.id)
716+
717+
# 批量保存段落和问题映射
718+
Paragraph.objects.bulk_create(paragraphs)
719+
ProblemParagraphMapping.objects.bulk_create(problem_paragraph_mappings)
720+
721+
# 批量保存聊天记录
722+
ChatRecord.objects.bulk_update(chat_record_list, ['improve_paragraph_id_list'])
723+
724+
return paragraph_ids, dataset_id

‎apps/application/swagger_api/chat_api.py

+32
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,38 @@ def get_request_body_api():
267267
}
268268
)
269269

270+
@staticmethod
271+
def get_request_body_api_post():
272+
return openapi.Schema(
273+
type=openapi.TYPE_OBJECT,
274+
required=['dataset_id', 'document_id', 'chat_ids'],
275+
properties={
276+
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
277+
description="知识库id"),
278+
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
279+
description="文档id"),
280+
'chat_ids': openapi.Schema(type=openapi.TYPE_ARRAY, title="会话id列表",
281+
description="会话id列表",
282+
items=openapi.Schema(type=openapi.TYPE_STRING))
283+
284+
}
285+
)
286+
287+
@staticmethod
288+
def get_request_params_api_post():
289+
return [openapi.Parameter(name='application_id',
290+
in_=openapi.IN_PATH,
291+
type=openapi.TYPE_STRING,
292+
required=True,
293+
description='应用id'),
294+
openapi.Parameter(name='dataset_id',
295+
in_=openapi.IN_PATH,
296+
type=openapi.TYPE_STRING,
297+
required=True,
298+
description='知识库id'),
299+
300+
]
301+
270302

271303
class VoteApi(ApiMixin):
272304
@staticmethod

‎apps/application/urls.py

+4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@
6161
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve',
6262
views.ChatView.ChatRecord.Improve.as_view(),
6363
name=''),
64+
path(
65+
'application/<str:application_id>/dataset/<str:dataset_id>/improve',
66+
views.ChatView.ChatRecord.Improve.as_view(),
67+
name=''),
6468
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/improve',
6569
views.ChatView.ChatRecord.ChatRecordImprove.as_view()),
6670
path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view(), name='application/message'),

‎apps/application/views/chat_views.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def post(self, request: Request, chat_id: str):
129129
'client_id': request.auth.client_id,
130130
'form_data': (request.data.get(
131131
'form_data') if 'form_data' in request.data else {}),
132-
'image_list': request.data.get('image_list') if 'image_list' in request.data else [],
132+
'image_list': request.data.get(
133+
'image_list') if 'image_list' in request.data else [],
133134
'client_type': request.auth.client_type}).chat()
134135

135136
@action(methods=['GET'], detail=False)
@@ -364,6 +365,28 @@ def put(self, request: Request, application_id: str, chat_id: str, chat_record_i
364365
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
365366
'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data))
366367

368+
@action(methods=['POST'], detail=False)
369+
@swagger_auto_schema(operation_summary="添加至知识库",
370+
operation_id="添加至知识库",
371+
manual_parameters=ImproveApi.get_request_params_api_post(),
372+
request_body=ImproveApi.get_request_body_api_post(),
373+
tags=["应用/对话日志/添加至知识库"]
374+
)
375+
@has_permissions(
376+
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
377+
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
378+
dynamic_tag=keywords.get('application_id'))],
379+
380+
), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
381+
[lambda r, keywords: Permission(group=Group.DATASET,
382+
operate=Operate.MANAGE,
383+
dynamic_tag=keywords.get(
384+
'dataset_id'))],
385+
compare=CompareConstants.AND
386+
), compare=CompareConstants.AND)
387+
def post(self, request: Request, application_id: str, dataset_id: str):
388+
return result.success(ChatRecordSerializer.PostImprove().post_improve(request.data))
389+
367390
class Operate(APIView):
368391
authentication_classes = [TokenAuth]
369392

@@ -417,4 +440,3 @@ def post(self, request: Request, application_id: str, chat_id: str):
417440
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
418441
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
419442
return result.success(file_ids)
420-

‎ui/src/api/log.ts

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { Result } from '@/request/Result'
2-
import { get, del, put, exportExcel, exportExcelPost } from '@/request/index'
2+
import { get, del, put, exportExcel, exportExcelPost, post } from '@/request/index'
33
import type { pageRequest } from '@/api/type/common'
44
import { type Ref } from 'vue'
55

@@ -114,7 +114,22 @@ const putChatRecordLog: (
114114
loading
115115
)
116116
}
117+
/**
118+
* 对话记录提交至知识库
119+
* @param data
120+
* @param loading
121+
* @param application_id
122+
* @param dataset_id
123+
*/
117124

125+
const postChatRecordLog: (
126+
application_id: string,
127+
dataset_id: string,
128+
data: any,
129+
loading?: Ref<boolean>
130+
) => Promise<Result<any>> = (application_id, dataset_id, data, loading) => {
131+
return post(`${prefix}/${application_id}/dataset/${dataset_id}/improve`, data, undefined, loading)
132+
}
118133
/**
119134
* 获取标注段落列表信息
120135
* @param 参数
@@ -215,5 +230,6 @@ export default {
215230
delMarkRecord,
216231
exportChatLog,
217232
getChatLogClient,
218-
delChatClientLog
233+
delChatClientLog,
234+
postChatRecordLog
219235
}

‎ui/src/views/log/component/EditContentDialog.vue

+22-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
filterable
8686
placeholder="请选择文档"
8787
:loading="optionLoading"
88+
@change="changeDocument"
8889
>
8990
<el-option
9091
v-for="item in documentList"
@@ -113,7 +114,7 @@ import logApi from '@/api/log'
113114
import imageApi from '@/api/image'
114115
import useStore from '@/stores'
115116
116-
const { application, document } = useStore()
117+
const { application, document, user } = useStore()
117118
118119
const route = useRoute()
119120
const {
@@ -216,10 +217,19 @@ const onUploadImg = async (files: any, callback: any) => {
216217
}
217218
218219
function changeDataset(id: string) {
220+
if (user.userInfo) {
221+
localStorage.setItem(user.userInfo.id + 'chat_dataset_id', id)
222+
}
219223
form.value.document_id = ''
220224
getDocument(id)
221225
}
222226
227+
function changeDocument(id: string) {
228+
if (user.userInfo) {
229+
localStorage.setItem(user.userInfo.id + 'chat_document_id', id)
230+
}
231+
}
232+
223233
function getDocument(id: string) {
224234
document.asyncGetAllDocument(id, loading).then((res: any) => {
225235
documentList.value = res.data
@@ -229,11 +239,22 @@ function getDocument(id: string) {
229239
function getDataset() {
230240
application.asyncGetApplicationDataset(id, loading).then((res: any) => {
231241
datasetList.value = res.data
242+
if (localStorage.getItem(user.userInfo?.id + 'chat_dataset_id')) {
243+
form.value.dataset_id = localStorage.getItem(user.userInfo?.id + 'chat_dataset_id') as string
244+
if (!datasetList.value.find((v) => v.id === form.value.dataset_id)) {
245+
form.value.dataset_id = ''
246+
} else {
247+
getDocument(form.value.dataset_id)
248+
}
249+
}
232250
})
233251
}
234252
235253
const open = (data: any) => {
236254
getDataset()
255+
if (localStorage.getItem(user.userInfo?.id + 'chat_document_id')) {
256+
form.value.document_id = localStorage.getItem(user.userInfo?.id + 'chat_document_id') as string
257+
}
237258
form.value.chat_id = data.chat_id
238259
form.value.record_id = data.id
239260
form.value.problem_text = data.problem_text ? data.problem_text.substring(0, 256) : ''

0 commit comments

Comments
 (0)