|
40 | 40 | from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
41 | 41 | from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
|
42 | 42 | from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
43 |
| -from embedding.task import embedding_by_paragraph |
| 43 | +from embedding.task import embedding_by_paragraph, embedding_by_paragraph_list |
44 | 44 | from setting.models import Model
|
45 | 45 | from setting.models_provider import get_model_credential
|
46 | 46 | from smartdoc.conf import PROJECT_DIR
|
@@ -658,3 +658,67 @@ def delete(self, with_valid=True):
|
658 | 658 | data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
|
659 | 659 | o.is_valid(raise_exception=True)
|
660 | 660 | return o.delete()
|
| 661 | + |
class PostImprove(serializers.Serializer):
    """Turn the records of the given chats into paragraphs of a document.

    For every chat record found, a ``Paragraph`` is created in the target
    document (question text as title, answer text as content), linked to a
    ``Problem`` through a ``ProblemParagraphMapping``, and the new paragraph
    id is appended to the record's ``improve_paragraph_id_list``.
    """
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    chat_ids = serializers.ListSerializer(child=serializers.UUIDField(), required=True,
                                          error_messages=ErrMessage.list("对话id"))

    def is_valid(self, *, raise_exception=False):
        """Validate the fields and check the document belongs to the dataset.

        Field validation always raises on failure, regardless of
        ``raise_exception`` (the argument is accepted only for signature
        compatibility with the base class).

        :return: ``True`` when validation succeeds.
        :raises AppApiException: when no document with ``document_id`` exists
            in the dataset ``dataset_id``.
        """
        super().is_valid(raise_exception=True)
        if not Document.objects.filter(id=self.data['document_id'], dataset_id=self.data['dataset_id']).exists():
            raise AppApiException(500, "文档id不正确")
        # Follow the usual is_valid() contract so `if s.is_valid():` works.
        return True

    @staticmethod
    def post_embedding_paragraph(paragraph_ids, dataset_id):
        """Embed the given paragraphs with the dataset's embedding model.

        :param paragraph_ids: ids of the paragraphs to embed.
        :param dataset_id: dataset whose embedding model id is looked up.
        """
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_paragraph_list(paragraph_ids, model_id)

    # NOTE(review): presumably `post` calls post_function with the tuple
    # returned by post_improve once the wrapped (atomic) call has finished;
    # confirm against the decorator's implementation.
    @post(post_function=post_embedding_paragraph)
    @transaction.atomic
    def post_improve(self, instance: Dict):
        """Create paragraphs and problem mappings from chat records in bulk.

        :param instance: dict with ``dataset_id``, ``document_id`` and
            ``chat_ids``; validated with this serializer before use.
        :return: ``(paragraph_ids, dataset_id)``, whose shape matches the
            parameters of ``post_embedding_paragraph``.
        :raises AppApiException: when the document does not belong to the
            dataset, or when any requested chat has no chat record.
        """
        ChatRecordSerializer.PostImprove(data=instance).is_valid(raise_exception=True)

        chat_ids = instance['chat_ids']
        document_id = instance['document_id']
        dataset_id = instance['dataset_id']

        # Fetch every chat record that belongs to the requested chats.
        chat_record_list = list(ChatRecord.objects.filter(chat_id__in=chat_ids))

        # A chat may own several records, so comparing list lengths cannot
        # detect a missing chat (and wrongly rejects duplicated ids).
        # Compare the chat ids actually found with the ones requested,
        # normalised to UUID objects so str and UUID inputs both match.
        found_chat_ids = {uuid.UUID(str(record.chat_id)) for record in chat_record_list}
        requested_chat_ids = {uuid.UUID(str(chat_id)) for chat_id in chat_ids}
        if not requested_chat_ids <= found_chat_ids:
            raise AppApiException(500, "存在不存在的对话记录")

        # Build paragraphs and problem-paragraph mappings for bulk insert.
        paragraphs = []
        paragraph_ids = []
        problem_paragraph_mappings = []
        for chat_record in chat_record_list:
            paragraph = Paragraph(
                id=uuid.uuid1(),
                document_id=document_id,
                content=chat_record.answer_text,
                dataset_id=dataset_id,
                title=chat_record.problem_text
            )
            # Reuse the dataset's existing problem with the same text, if any.
            problem, _ = Problem.objects.get_or_create(content=chat_record.problem_text, dataset_id=dataset_id)
            problem_paragraph_mapping = ProblemParagraphMapping(
                id=uuid.uuid1(),
                dataset_id=dataset_id,
                document_id=document_id,
                problem_id=problem.id,
                paragraph_id=paragraph.id
            )
            paragraphs.append(paragraph)
            paragraph_ids.append(paragraph.id)
            problem_paragraph_mappings.append(problem_paragraph_mapping)
            chat_record.improve_paragraph_id_list.append(paragraph.id)

        # Persist paragraphs and mappings in bulk.
        Paragraph.objects.bulk_create(paragraphs)
        ProblemParagraphMapping.objects.bulk_create(problem_paragraph_mappings)

        # Persist the updated paragraph-id lists on the chat records.
        ChatRecord.objects.bulk_update(chat_record_list, ['improve_paragraph_id_list'])

        return paragraph_ids, dataset_id
0 commit comments