# coding=utf-8
import json
import time
import uuid
from typing import Dict

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.application_node.i_application_node import IApplicationNode
from application.models import Chat
from common.handle.impl.response.openai_to_response import OpenaiToResponse


def string_to_uuid(input_str):
    # Derive a deterministic UUID from the input string (uuid5 is stable for the same input).
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
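# Illustrative note: string_to_uuid(chat_id + application_id) always yields the same UUID
# for the same chat/application pair, so the embedded Chat created below with
# get_or_create is created once and then reused on later calls.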


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    # Record token usage, the question, the answer and the run time on the node context,
    # and append the answer to the workflow output when this node produces the final result.
    result = node_variable.get('result')
    node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
    node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
    node.context['answer'] = answer
    node.context['question'] = node_variable['question']
    node.context['run_time'] = time.time() - node.context['start_time']
    if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
        workflow.answer += answer


def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data (streaming).
    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node
    @param workflow: workflow manager
    """
    response = node_variable.get('result')
    answer = ''
    usage = {}
    for chunk in response:
        # Decode the chunk to text and strip the leading 'data: ' SSE prefix before parsing the JSON payload.
        response_content = chunk.decode('utf-8')[6:]
        response_content = json.loads(response_content)
        choices = response_content.get('choices')
        if choices and isinstance(choices, list) and len(choices) > 0:
            # Guard against a null content field in the delta.
            content = choices[0].get('delta', {}).get('content', '') or ''
            answer += content
            yield content
        usage = response_content.get('usage', {})
    node_variable['result'] = {'usage': usage}
    _write_context(node_variable, workflow_variable, node, workflow, answer)
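# A single streamed chunk is expected to look like (illustrative values):
#   b'data: {"choices": [{"delta": {"content": "Hello"}}], "usage": {}}'
# from which the generator above yields 'Hello' and accumulates it into the final answer.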


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data.
    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node instance object
    @param workflow: workflow manager
    """
    response = node_variable.get('result')['choices'][0]['message']
    # Fall back to a default "no relevant content found" prompt when the answer is empty.
    answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
    _write_context(node_variable, workflow_variable, node, workflow, answer)


class BaseApplicationNode(IApplicationNode):

    def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
                **kwargs) -> NodeResult:
        from application.serializers.chat_message_serializers import ChatMessageSerializer
        # Generate a deterministic chat_id for the embedded application.
        current_chat_id = string_to_uuid(chat_id + application_id)
        Chat.objects.get_or_create(id=current_chat_id, defaults={
            'application_id': application_id,
            'abstract': message
        })
        response = ChatMessageSerializer(
            data={'chat_id': current_chat_id, 'message': message,
                  're_chat': re_chat,
                  'stream': stream,
                  'application_id': application_id,
                  'client_id': client_id,
                  'client_type': client_type, 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse())
        if response.status_code == 200:
            if stream:
                content_generator = response.streaming_content
                return NodeResult({'result': content_generator, 'question': message}, {},
                                  _write_context=write_context_stream)
            else:
                data = json.loads(response.content)
                return NodeResult({'result': data, 'question': message}, {},
                                  _write_context=write_context)
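    # The NodeResult carries the raw response plus a _write_context callback
    # (write_context_stream for streaming, write_context otherwise); the workflow
    # engine presumably invokes it later to fill node.context and the workflow answer.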

    def get_details(self, index: int, **kwargs):
        global_fields = []
        for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
            global_fields.append({
                'label': api_input_field['variable'],
                'key': api_input_field['variable'],
                'value': self.workflow_manage.get_reference_field(
                    api_input_field['value'][0],
                    api_input_field['value'][1:])
            })
        for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
            global_fields.append({
                'label': user_input_field['label'],
                'key': user_input_field['field'],
                'value': self.workflow_manage.get_reference_field(
                    user_input_field['value'][0],
                    user_input_field['value'][1:])
            })
        return {
            'name': self.node.properties.get('stepName'),
            'index': index,
            'info': self.node.properties.get('node_data'),
            'run_time': self.context.get('run_time'),
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'global_fields': global_fields
        }