diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index 62273818cb3..23537f289be 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -7,6 +7,7 @@ @desc: """ from .ai_chat_step_node import * +from .application_node import BaseApplicationNode from .condition_node import * from .question_node import * from .search_dataset_node import * @@ -17,7 +18,7 @@ from .reranker_node import * node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/application_node/__init__.py b/apps/application/flow/step_node/application_node/__init__.py new file mode 100644 index 00000000000..d1ea91ca7f8 --- /dev/null +++ b/apps/application/flow/step_node/application_node/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .impl import * diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py new file mode 100644 index 00000000000..b11fa00232f --- /dev/null +++ b/apps/application/flow/step_node/application_node/i_application_node.py @@ -0,0 +1,40 @@ +# coding=utf-8 +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ApplicationNodeSerializer(serializers.Serializer): + application_id = serializers.CharField(required=True, error_messages=ErrMessage.char("应用id")) + question_reference_address = serializers.ListField(required=True, error_messages=ErrMessage.list("用户问题")) + api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list("api输入字段")) + user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段")) + + +class IApplicationNode(INode): + type = 'application-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ApplicationNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + kwargs = {} + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(api_input_field['value'][0], + api_input_field['value'][1:]) + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(user_input_field['value'][0], + user_input_field['value'][1:]) + + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data, + message=str(question), **kwargs) + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/application_node/impl/__init__.py b/apps/application/flow/step_node/application_node/impl/__init__.py new file mode 100644 index 00000000000..e31a8d885cd --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .base_application_node import BaseApplicationNode diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py new file mode 100644 index 00000000000..7f4644a5815 --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -0,0 +1,124 @@ +# coding=utf-8 +import json +import time +import uuid +from typing import List, Dict +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.application_node.i_application_node import IApplicationNode +from application.models import Chat +from common.handle.impl.response.openai_to_response import OpenaiToResponse + + +def string_to_uuid(input_str): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str)) + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + result = node_variable.get('result') + node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) + node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) + node.context['answer'] = answer + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + workflow.answer += answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' + usage = {} + for chunk in response: + # 先把流转成字符串 + response_content = chunk.decode('utf-8')[6:] + response_content = json.loads(response_content) + choices = response_content.get('choices') + if choices and isinstance(choices, list) and len(choices) > 0: + content = choices[0].get('delta', {}).get('content', '') + answer += content + yield content + usage = response_content.get('usage', {}) + node_variable['result'] = {'usage': usage} + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result')['choices'][0]['message'] + answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。" + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +class BaseApplicationNode(IApplicationNode): + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + **kwargs) -> NodeResult: + from application.serializers.chat_message_serializers import ChatMessageSerializer + # 生成嵌入应用的chat_id + current_chat_id = string_to_uuid(chat_id + application_id) + Chat.objects.get_or_create(id=current_chat_id, defaults={ + 'application_id': application_id, + 'abstract': message + }) + response = ChatMessageSerializer( + data={'chat_id': current_chat_id, 'message': message, + 're_chat': re_chat, + 'stream': stream, + 'application_id': application_id, + 'client_id': client_id, + 'client_type': client_type, 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse()) + if response.status_code == 200: + if stream: + content_generator = response.streaming_content + return NodeResult({'result': content_generator, 'question': message}, {}, + _write_context=write_context_stream) + else: + data = json.loads(response.content) + return NodeResult({'result': data, 'question': message}, {}, + _write_context=write_context) + + def get_details(self, index: int, **kwargs): + global_fields = [] + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + global_fields.append({ + 'label': api_input_field['variable'], + 'key': api_input_field['variable'], + 'value': self.workflow_manage.get_reference_field( + api_input_field['value'][0], + api_input_field['value'][1:]) + }) + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + global_fields.append({ + 'label': user_input_field['label'], + 'key': user_input_field['field'], + 'value': self.workflow_manage.get_reference_field( + user_input_field['value'][0], + user_input_field['value'][1:]) + }) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "info": self.node.properties.get('node_data'), + 'run_time': self.context.get('run_time'), + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'global_fields': global_fields + } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d2e99bce85d..0e5cb15fbcb 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -53,7 +53,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa self.__setattr__(keyword, kwargs.get(keyword)) -end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node'] +end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node'] class Flow: diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 0232313e824..86460f093bc 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -15,14 +15,11 @@ from functools import reduce from typing import Dict, List -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.core import cache, validators from django.core import signing -from django.core.paginator import Paginator from django.db import transaction, models from django.db.models import QuerySet, Q -from django.forms import CharField from django.http import HttpResponse from django.template import Template, Context from rest_framework import serializers @@ -46,10 +43,9 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from embedding.models import SearchMode from function_lib.serializers.function_lib_serializer import FunctionLibSerializer -from setting.models import AuthOperate, TeamMember, TeamMemberPermission +from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider import get_model_credential -from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR @@ -979,6 +975,17 @@ def play_demo_text(self, form_data, with_valid=True): model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data) return model.text_to_speech(text) + def application_list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + application_id = self.data.get('application_id') + application = Application.objects.filter(user_id=user_id).exclude(id=application_id) + # 把应用的type为WORK_FLOW的应用放到最上面 然后再按名称排序 + serialized_data = ApplicationSerializerModel(application, many=True).data + application = sorted(serialized_data, key=lambda x: (x['type'] != 'WORK_FLOW', x['name'])) + return list(application) + class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: model = ApplicationApiKey diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 488c244f973..61051f96d05 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -305,6 +305,8 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), 'stream': stream, 're_chat': re_chat, + 'client_id': client_id, + 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data) r = work_flow_manage.run() diff --git a/apps/application/urls.py b/apps/application/urls.py index b3df23d73a2..5bd551b7b58 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -22,6 +22,7 @@ path('application/<str:application_id>/function_lib', views.Application.FunctionLib.as_view()), path('application/<str:application_id>/function_lib/<str:function_lib_id>', views.Application.FunctionLib.Operate.as_view()), + path('application/<str:application_id>/application', views.Application.Application.as_view()), path('application/<str:application_id>/model_params_form/<str:model_id>', views.Application.ModelParamsForm.as_view()), path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 64b6c367b0a..f0873d62c74 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -243,6 +243,24 @@ def get(self, request: Request, application_id: str, function_lib_id: str): data={'application_id': application_id, 'user_id': request.user.id}).get_function_lib(function_lib_id)) + class Application(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取当前人创建的应用列表", + operation_id="获取当前人创建的应用列表", + tags=["应用/会话"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).application_list()) + class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 922bbfc3c34..78eded77348 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -43,7 +43,7 @@ class ChatView(APIView): class Export(APIView): authentication_classes = [TokenAuth] - @action(methods=['GET'], detail=False) + @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="导出对话", operation_id="导出对话", manual_parameters=ChatApi.get_request_params_api(), diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 92f9aaae49b..f41268e15ac 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -309,6 +309,18 @@ const listFunctionLib: (application_id: String, loading?: Ref<boolean>) => Promi ) => { return get(`${prefix}/${application_id}/function_lib`, undefined, loading) } +/** + * 获取当前人的所有应用列表 + * @param application_id 应用id + * @param loading + * @returns + */ +export const getApplicationList: ( + application_id: string, + loading?: Ref<boolean> +) => Promise<Result<any>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/application`, undefined, loading) +} /** * 获取应用所属的函数库 * @param application_id @@ -500,5 +512,6 @@ export default { getWorkFlowVersionDetail, putWorkFlowVersion, playDemoText, - getUserList + getUserList, + getApplicationList } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index ebee3cc628f..72dfbcb6cb2 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -17,7 +17,12 @@ <el-icon class="mr-8 arrow-icon" :class="current === index ? 'rotate-90' : ''" ><CaretRight /></el-icon> - <component :is="iconComponent(`${item.type}-icon`)" class="mr-8" :size="24" /> + <component + :is="iconComponent(`${item.type}-icon`)" + class="mr-8" + :size="24" + :item="item.info" + /> <h4>{{ item.name }}</h4> </div> <div class="flex align-center"> @@ -37,7 +42,11 @@ <div class="mt-12" v-if="current === index"> <template v-if="item.status === 200"> <!-- 开始 --> - <template v-if="item.type === WorkflowType.Start"> + <template + v-if=" + item.type === WorkflowType.Start || item.type === WorkflowType.Application + " + > <div class="card-never border-r-4"> <h5 class="p-8-12">参数输入</h5> <div class="p-8-12 border-t-dashed lighter"> @@ -84,15 +93,25 @@ </template> <!-- AI 对话 / 问题优化--> <template - v-if="item.type == WorkflowType.AiChat || item.type == WorkflowType.Question" + v-if=" + item.type == WorkflowType.AiChat || + item.type == WorkflowType.Question || + item.type == WorkflowType.Application + " > - <div class="card-never border-r-4"> + <div + class="card-never border-r-4" + v-if="item.type !== WorkflowType.Application" + > <h5 class="p-8-12">角色设定 (System)</h5> <div class="p-8-12 border-t-dashed lighter"> {{ item.system || '-' }} </div> </div> - <div class="card-never border-r-4 mt-8"> + <div + class="card-never border-r-4 mt-8" + v-if="item.type !== WorkflowType.Application" + > <h5 class="p-8-12">历史记录</h5> <div class="p-8-12 border-t-dashed lighter"> <template v-if="item.history_message?.length > 0"> @@ -115,7 +134,9 @@ </div> </div> <div class="card-never border-r-4 mt-8"> - <h5 class="p-8-12">AI 回答</h5> + <h5 class="p-8-12"> + {{ item.type == WorkflowType.Application ? '应用回答' : 'AI 回答' }} + </h5> <div class="p-8-12 border-t-dashed lighter"> <MdPreview v-if="item.answer" @@ -269,7 +290,7 @@ watch(dialogVisible, (bool) => { const open = (data: any) => { detail.value = cloneDeep(data) - + console.log(detail.value) dialogVisible.value = true } onBeforeUnmount(() => { diff --git a/ui/src/enums/workflow.ts b/ui/src/enums/workflow.ts index be2571a5c3e..aaf636cc22a 100644 --- a/ui/src/enums/workflow.ts +++ b/ui/src/enums/workflow.ts @@ -8,5 +8,6 @@ export enum WorkflowType { Reply = 'reply-node', FunctionLib = 'function-lib-node', FunctionLibCustom = 'function-node', - RrerankerNode = 'reranker-node' + RrerankerNode = 'reranker-node', + Application = 'application-node' } diff --git a/ui/src/views/application-workflow/component/DropdownMenu.vue b/ui/src/views/application-workflow/component/DropdownMenu.vue index e9e03e45d75..33381506435 100644 --- a/ui/src/views/application-workflow/component/DropdownMenu.vue +++ b/ui/src/views/application-workflow/component/DropdownMenu.vue @@ -43,7 +43,7 @@ <template v-for="(item, index) in filter_function_lib_list" :key="index"> <div class="workflow-dropdown-item cursor flex p-8-12" - @click.stop="clickNodes(functionLibNode, item)" + @click.stop="clickNodes(functionLibNode, item, 'function')" @mousedown.stop="onmousedown(functionLibNode, item)" > <component @@ -59,14 +59,45 @@ </template> </el-scrollbar> </el-tab-pane> + <el-tab-pane label="应用" name="application"> + <el-scrollbar height="400"> + <template v-for="(item, index) in filter_application_list" :key="index"> + <div + class="workflow-dropdown-item cursor flex p-8-12" + @click.stop="clickNodes(applicationNode, item, 'application')" + @mousedown.stop="onmousedown(applicationNode, item, 'application')" + > + <component + :is="iconComponent(`application-node-icon`)" + class="mr-8 mt-4" + :size="32" + :item="item" + /> + <div class="pre-wrap" style="width: 60%"> + <auto-tooltip :content="item.name" style="width: 80%" class="lighter"> + {{ item.name }} + </auto-tooltip> + <el-text type="info" size="small" style="width: 80%">{{ item.desc }}</el-text> + </div> + <div class="status-tag" style="margin-left: auto"> + <el-tag type="warning" v-if="isWorkFlow(item.type)" style="height: 22px" + >高级编排</el-tag + > + <el-tag class="blue-tag" v-else style="height: 22px">简单配置</el-tag> + </div> + </div> + </template> + </el-scrollbar> + </el-tab-pane> </el-tabs> </div> </template> <script setup lang="ts"> import { ref, onMounted, computed } from 'vue' -import { menuNodes, functionLibNode, functionNode } from '@/workflow/common/data' +import { menuNodes, functionLibNode, functionNode, applicationNode } from '@/workflow/common/data' import { iconComponent } from '@/workflow/icons/utils' import applicationApi from '@/api/application' +import { isWorkFlow } from '@/utils/application' const search_text = ref<string>('') const props = defineProps({ show: { @@ -91,21 +122,48 @@ const filter_function_lib_list = computed(() => { item.name.toLocaleLowerCase().includes(search_text.value.toLocaleLowerCase()) ) }) +const applicationList = ref<any[]>([]) +const filter_application_list = computed(() => { + return applicationList.value.filter((item: any) => + item.name.toLocaleLowerCase().includes(search_text.value.toLocaleLowerCase()) + ) +}) + const filter_menu_nodes = computed(() => { return menuNodes.filter((item) => item.label.toLocaleLowerCase().includes(search_text.value.toLocaleLowerCase()) ) }) -function clickNodes(item: any, data?: any) { +function clickNodes(item: any, data?: any, type?: string) { if (data) { item['properties']['stepName'] = data.name - item['properties']['node_data'] = { - ...data, - function_lib_id: data.id, - input_field_list: data.input_field_list.map((field: any) => ({ - ...field, - value: field.source == 'reference' ? [] : '' - })) + if (type == 'function') { + item['properties']['node_data'] = { + ...data, + function_lib_id: data.id, + input_field_list: data.input_field_list.map((field: any) => ({ + ...field, + value: field.source == 'reference' ? [] : '' + })) + } + } + if (type == 'application') { + if (isWorkFlow(data.type)) { + console.log(data.work_flow.nodes[0].properties.api_input_field_list) + item['properties']['node_data'] = { + name: data.name, + icon: data.icon, + application_id: data.id, + api_input_field_list: data.work_flow.nodes[0].properties.api_input_field_list, + user_input_field_list: data.work_flow.nodes[0].properties.user_input_field_list + } + } else { + item['properties']['node_data'] = { + name: data.name, + icon: data.icon, + application_id: data.id + } + } } } props.workflowRef?.addNode(item) @@ -113,16 +171,35 @@ function clickNodes(item: any, data?: any) { emit('clickNodes', item) } -function onmousedown(item: any, data?: any) { +function onmousedown(item: any, data?: any, type?: string) { if (data) { item['properties']['stepName'] = data.name - item['properties']['node_data'] = { - ...data, - function_lib_id: data.id, - input_field_list: data.input_field_list.map((field: any) => ({ - ...field, - value: field.source == 'reference' ? [] : '' - })) + if (type == 'function') { + item['properties']['node_data'] = { + ...data, + function_lib_id: data.id, + input_field_list: data.input_field_list.map((field: any) => ({ + ...field, + value: field.source == 'reference' ? [] : '' + })) + } + } + if (type == 'application') { + if (isWorkFlow(data.type)) { + item['properties']['node_data'] = { + name: data.name, + icon: data.icon, + application_id: data.id, + api_input_field_list: data.work_flow.nodes[0].properties.api_input_field_list, + user_input_field_list: data.work_flow.nodes[0].properties.user_input_field_list + } + } else { + item['properties']['node_data'] = { + name: data.name, + icon: data.icon, + application_id: data.id + } + } } } props.workflowRef?.onmousedown(item) @@ -133,6 +210,9 @@ function getList() { applicationApi.listFunctionLib(props.id, loading).then((res: any) => { functionLibList.value = res.data }) + applicationApi.getApplicationList(props.id, loading).then((res: any) => { + applicationList.value = res.data + }) } onMounted(() => { diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index 50619d2d1a3..59df25c6880 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -11,7 +11,12 @@ class="flex align-center" :style="{ maxWidth: node_status == 200 ? 'calc(100% - 55px)' : 'calc(100% - 85px)' }" > - <component :is="iconComponent(`${nodeModel.type}-icon`)" class="mr-8" :size="24" /> + <component + :is="iconComponent(`${nodeModel.type}-icon`)" + class="mr-8" + :size="24" + :item="nodeModel?.properties.node_data" + /> <h4 v-if="showOperate(nodeModel.type)" style="max-width: 90%"> <ReadWrite @mousemove.stop diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index 8cca9bcf9f9..cfb9fe2a463 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -215,6 +215,24 @@ export const functionLibNode = { } } +export const applicationNode = { + type: WorkflowType.Application, + text: '应用节点', + label: '应用节点', + height: 260, + properties: { + stepName: '应用节点', + config: { + fields: [ + { + label: '结果', + value: 'result' + } + ] + } + } +} + export const compareList = [ { value: 'is_null', label: '为空' }, { value: 'is_not_null', label: '不为空' }, @@ -242,7 +260,8 @@ export const nodeDict: any = { [WorkflowType.Reply]: replyNode, [WorkflowType.FunctionLib]: functionLibNode, [WorkflowType.FunctionLibCustom]: functionNode, - [WorkflowType.RrerankerNode]: rerankerNode + [WorkflowType.RrerankerNode]: rerankerNode, + [WorkflowType.Application]: applicationNode } export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/common/validate.ts b/ui/src/workflow/common/validate.ts index 000d13c0c1d..86f7d8af373 100644 --- a/ui/src/workflow/common/validate.ts +++ b/ui/src/workflow/common/validate.ts @@ -4,7 +4,8 @@ const end_nodes: Array<string> = [ WorkflowType.AiChat, WorkflowType.Reply, WorkflowType.FunctionLib, - WorkflowType.FunctionLibCustom + WorkflowType.FunctionLibCustom, + WorkflowType.Application ] export class WorkFlowInstance { nodes diff --git a/ui/src/workflow/icons/application-node-icon.vue b/ui/src/workflow/icons/application-node-icon.vue new file mode 100644 index 00000000000..928980c18b7 --- /dev/null +++ b/ui/src/workflow/icons/application-node-icon.vue @@ -0,0 +1,29 @@ +<template> + <AppAvatar + v-if="isAppIcon(item.icon)" + shape="square" + :size="32" + style="background: none" + class="mr-8" + > + <img :src="item.icon" alt="" /> + </AppAvatar> + <AppAvatar + v-else-if="item?.name" + :name="item?.name" + pinyinColor + shape="square" + :size="32" + class="mr-8" + /> +</template> +<script setup lang="ts"> +import { isAppIcon } from '@/utils/application' +import { defineProps } from 'vue' +const props = defineProps<{ + item: { + name: string + icon: string + } +}>() +</script> diff --git a/ui/src/workflow/nodes/application-node/index.ts b/ui/src/workflow/nodes/application-node/index.ts new file mode 100644 index 00000000000..6292ab5df20 --- /dev/null +++ b/ui/src/workflow/nodes/application-node/index.ts @@ -0,0 +1,12 @@ +import ChatNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class ChatNode extends AppNode { + constructor(props: any) { + super(props, ChatNodeVue) + } +} +export default { + type: 'application-node', + model: AppNodeModel, + view: ChatNode +} diff --git a/ui/src/workflow/nodes/application-node/index.vue b/ui/src/workflow/nodes/application-node/index.vue new file mode 100644 index 00000000000..922da2ba217 --- /dev/null +++ b/ui/src/workflow/nodes/application-node/index.vue @@ -0,0 +1,132 @@ +<template> + <NodeContainer :nodeModel="nodeModel"> + <h5 class="title-decoration-1 mb-8">节点设置</h5> + <el-card shadow="never" class="card-never"> + <el-form + @submit.prevent + :model="form_data" + label-position="top" + require-asterisk-position="right" + label-width="auto" + ref="applicationNodeFormRef" + > + <el-form-item + label="用户问题" + prop="question_reference_address" + :rules="{ + message: '请选择检索问题', + trigger: 'blur', + required: true + }" + > + <NodeCascader + ref="nodeCascaderRef" + :nodeModel="nodeModel" + class="w-full" + placeholder="请选择检索问题" + v-model="form_data.question_reference_address" + /> + </el-form-item> + <div v-for="(field, index) in form_data.api_input_field_list" :key="'api-input-' + index"> + <el-form-item + :label="field.variable" + :prop="'api_input_field_list.' + index + '.value'" + :rules="[ + { required: field.is_required, message: `请输入${field.variable}`, trigger: 'blur' } + ]" + > + <NodeCascader + ref="nodeCascaderRef" + :nodeModel="nodeModel" + class="w-full" + placeholder="请选择检索问题" + v-model="form_data.api_input_field_list[index].value" + /> + </el-form-item> + </div> + + <!-- Loop through dynamic fields for user_input_field_list --> + <div v-for="(field, index) in form_data.user_input_field_list" :key="'user-input-' + index"> + <el-form-item + :label="field.label" + :prop="'user_input_field_list.' + index + '.value'" + :rules="[ + { required: field.required, message: `请输入${field.label}`, trigger: 'blur' } + ]" + > + <NodeCascader + ref="nodeCascaderRef" + :nodeModel="nodeModel" + class="w-full" + placeholder="请选择检索问题" + v-model="form_data.user_input_field_list[index].value" + /> + </el-form-item> + </div> + <el-form-item label="返回内容" @click.prevent> + <template #label> + <div class="flex align-center"> + <div class="mr-4"> + <span>返回内容<span class="danger">*</span></span> + </div> + <el-tooltip effect="dark" placement="right" popper-class="max-w-200"> + <template #content> + 关闭后该节点的内容则不输出给用户。 + 如果你想让用户看到该节点的输出内容,请打开开关。 + </template> + <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon> + </el-tooltip> + </div> + </template> + <el-switch size="small" v-model="form_data.is_result" /> + </el-form-item> + </el-form> + </el-card> + </NodeContainer> +</template> + +<script setup lang="ts"> +import { set, groupBy } from 'lodash' +import { app } from '@/main' +import NodeContainer from '@/workflow/common/NodeContainer.vue' +import { ref, computed, onMounted } from 'vue' +import NodeCascader from '@/workflow/common/NodeCascader.vue' +import type { FormInstance } from 'element-plus' + +const form = { + question_reference_address: [], + api_input_field_list: [], + user_input_field_list: [] +} + +const applicationNodeFormRef = ref<FormInstance>() + +const form_data = computed({ + get: () => { + if (props.nodeModel.properties.node_data) { + return props.nodeModel.properties.node_data + } else { + set(props.nodeModel.properties, 'node_data', form) + } + return props.nodeModel.properties.node_data + }, + set: (value) => { + set(props.nodeModel.properties, 'node_data', value) + } +}) + +const props = defineProps<{ nodeModel: any }>() + +const validate = () => { + return applicationNodeFormRef.value?.validate().catch((err) => { + return Promise.reject({ node: props.nodeModel, errMessage: err }) + }) +} + +onMounted(() => { + console.log(applicationNodeFormRef.value) + set(props.nodeModel, 'validate', validate) +}) +</script> + +<style lang="scss" scoped></style> diff --git a/ui/src/workflow/nodes/search-dataset-node/index.vue b/ui/src/workflow/nodes/search-dataset-node/index.vue index 64013673c7b..61eb8aa5805 100644 --- a/ui/src/workflow/nodes/search-dataset-node/index.vue +++ b/ui/src/workflow/nodes/search-dataset-node/index.vue @@ -198,6 +198,7 @@ function refresh() { } const validate = () => { + console.log(DatasetNodeFormRef.value) return Promise.all([ nodeCascaderRef.value.validate(), DatasetNodeFormRef.value?.validate()