Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: 工作流增加嵌入应用 #1598

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
@desc:
"""
from .ai_chat_step_node import *
from .application_node import BaseApplicationNode
from .condition_node import *
from .question_node import *
from .search_dataset_node import *
@@ -17,7 +18,7 @@
from .reranker_node import *

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode]


def get_node(node_type):
2 changes: 2 additions & 0 deletions apps/application/flow/step_node/application_node/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
from .impl import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# coding=utf-8
from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class ApplicationNodeSerializer(serializers.Serializer):
application_id = serializers.CharField(required=True, error_messages=ErrMessage.char("应用id"))
question_reference_address = serializers.ListField(required=True, error_messages=ErrMessage.list("用户问题"))
api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list("api输入字段"))
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))


class IApplicationNode(INode):
type = 'application-node'

def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ApplicationNodeSerializer

def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
kwargs = {}
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(api_input_field['value'][0],
api_input_field['value'][1:])
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(user_input_field['value'][0],
user_input_field['value'][1:])

return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
message=str(question), **kwargs)

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
from .base_application_node import BaseApplicationNode
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# coding=utf-8
import json
import time
import uuid
from typing import List, Dict
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.application_node.i_application_node import IApplicationNode
from application.models import Chat
from common.handle.impl.response.openai_to_response import OpenaiToResponse


def string_to_uuid(input_str):
return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
result = node_variable.get('result')
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
node.context['answer'] = answer
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
workflow.answer += answer


def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
usage = {}
for chunk in response:
# 先把流转成字符串
response_content = chunk.decode('utf-8')[6:]
response_content = json.loads(response_content)
choices = response_content.get('choices')
if choices and isinstance(choices, list) and len(choices) > 0:
content = choices[0].get('delta', {}).get('content', '')
answer += content
yield content
usage = response_content.get('usage', {})
node_variable['result'] = {'usage': usage}
_write_context(node_variable, workflow_variable, node, workflow, answer)


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result')['choices'][0]['message']
answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
_write_context(node_variable, workflow_variable, node, workflow, answer)


class BaseApplicationNode(IApplicationNode):

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
**kwargs) -> NodeResult:
from application.serializers.chat_message_serializers import ChatMessageSerializer
# 生成嵌入应用的chat_id
current_chat_id = string_to_uuid(chat_id + application_id)
Chat.objects.get_or_create(id=current_chat_id, defaults={
'application_id': application_id,
'abstract': message
})
response = ChatMessageSerializer(
data={'chat_id': current_chat_id, 'message': message,
're_chat': re_chat,
'stream': stream,
'application_id': application_id,
'client_id': client_id,
'client_type': client_type, 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse())
if response.status_code == 200:
if stream:
content_generator = response.streaming_content
return NodeResult({'result': content_generator, 'question': message}, {},
_write_context=write_context_stream)
else:
data = json.loads(response.content)
return NodeResult({'result': data, 'question': message}, {},
_write_context=write_context)

def get_details(self, index: int, **kwargs):
global_fields = []
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
global_fields.append({
'label': api_input_field['variable'],
'key': api_input_field['variable'],
'value': self.workflow_manage.get_reference_field(
api_input_field['value'][0],
api_input_field['value'][1:])
})
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
global_fields.append({
'label': user_input_field['label'],
'key': user_input_field['field'],
'value': self.workflow_manage.get_reference_field(
user_input_field['value'][0],
user_input_field['value'][1:])
})
return {
'name': self.node.properties.get('stepName'),
"index": index,
"info": self.node.properties.get('node_data'),
'run_time': self.context.get('run_time'),
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'global_fields': global_fields
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于提供的代码已经经过格式化和结构整理,因此在当前的情况下无法检测出不规范内容或潜在问题。然而,在处理实际项目时,您还需要关注以下几点:

  • 代码可能需要根据具体的业务需求进行调整。
  • 在写入上下文数据的过程中是否考虑了性能问题?
  • 是否存在与外部依赖关系有关的问题?

此外,请注意:

  • 如果涉及到敏感信息(如API密钥、聊天记录等),确保正确地存储这些值并保护它们免受未授权访问风险。

如果有详细的知识更新或者新的编程实践要求,可以参考最新的编码标准和最佳实务来优化现有代码库。

总的来说,尽管代码看起来简洁,但如果涉及跨部门合作或其他具体领域的话,还是可能存在特定领域的细微差别需要注意。

2 changes: 1 addition & 1 deletion apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa
self.__setattr__(keyword, kwargs.get(keyword))


end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node']
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node']


class Flow:
17 changes: 12 additions & 5 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
@@ -15,14 +15,11 @@
from functools import reduce
from typing import Dict, List

from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators
from django.core import signing
from django.core.paginator import Paginator
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.forms import CharField
from django.http import HttpResponse
from django.template import Template, Context
from rest_framework import serializers
@@ -46,10 +43,9 @@
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer
from setting.models import AuthOperate, TeamMember, TeamMemberPermission
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
@@ -979,6 +975,17 @@ def play_demo_text(self, form_data, with_valid=True):
model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data)
return model.text_to_speech(text)

def application_list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
application_id = self.data.get('application_id')
application = Application.objects.filter(user_id=user_id).exclude(id=application_id)
# 把应用的type为WORK_FLOW的应用放到最上面 然后再按名称排序
serialized_data = ApplicationSerializerModel(application, many=True).data
application = sorted(serialized_data, key=lambda x: (x['type'] != 'WORK_FLOW', x['name']))
return list(application)

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey
2 changes: 2 additions & 0 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
@@ -305,6 +305,8 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat,
'client_id': client_id,
'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data)
r = work_flow_manage.run()
1 change: 1 addition & 0 deletions apps/application/urls.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
path('application/<str:application_id>/function_lib', views.Application.FunctionLib.as_view()),
path('application/<str:application_id>/function_lib/<str:function_lib_id>',
views.Application.FunctionLib.Operate.as_view()),
path('application/<str:application_id>/application', views.Application.Application.as_view()),
path('application/<str:application_id>/model_params_form/<str:model_id>',
views.Application.ModelParamsForm.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
18 changes: 18 additions & 0 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
@@ -243,6 +243,24 @@ def get(self, request: Request, application_id: str, function_lib_id: str):
data={'application_id': application_id,
'user_id': request.user.id}).get_function_lib(function_lib_id))

class Application(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取当前人创建的应用列表",
operation_id="获取当前人创建的应用列表",
tags=["应用/会话"])
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id,
'user_id': request.user.id}).application_list())

class Profile(APIView):
authentication_classes = [TokenAuth]

2 changes: 1 addition & 1 deletion apps/application/views/chat_views.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ class ChatView(APIView):
class Export(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="导出对话",
operation_id="导出对话",
manual_parameters=ChatApi.get_request_params_api(),
15 changes: 14 additions & 1 deletion ui/src/api/application.ts
Original file line number Diff line number Diff line change
@@ -309,6 +309,18 @@ const listFunctionLib: (application_id: String, loading?: Ref<boolean>) => Promi
) => {
return get(`${prefix}/${application_id}/function_lib`, undefined, loading)
}
/**
* 获取当前人的所有应用列表
* @param application_id 应用id
* @param loading
* @returns
*/
export const getApplicationList: (
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/application`, undefined, loading)
}
/**
* 获取应用所属的函数库
* @param application_id
@@ -500,5 +512,6 @@ export default {
getWorkFlowVersionDetail,
putWorkFlowVersion,
playDemoText,
getUserList
getUserList,
getApplicationList
}
35 changes: 28 additions & 7 deletions ui/src/components/ai-chat/ExecutionDetailDialog.vue
Original file line number Diff line number Diff line change
@@ -17,7 +17,12 @@
<el-icon class="mr-8 arrow-icon" :class="current === index ? 'rotate-90' : ''"
><CaretRight
/></el-icon>
<component :is="iconComponent(`${item.type}-icon`)" class="mr-8" :size="24" />
<component
:is="iconComponent(`${item.type}-icon`)"
class="mr-8"
:size="24"
:item="item.info"
/>
<h4>{{ item.name }}</h4>
</div>
<div class="flex align-center">
@@ -37,7 +42,11 @@
<div class="mt-12" v-if="current === index">
<template v-if="item.status === 200">
<!-- 开始 -->
<template v-if="item.type === WorkflowType.Start">
<template
v-if="
item.type === WorkflowType.Start || item.type === WorkflowType.Application
"
>
<div class="card-never border-r-4">
<h5 class="p-8-12">参数输入</h5>
<div class="p-8-12 border-t-dashed lighter">
@@ -84,15 +93,25 @@
</template>
<!-- AI 对话 / 问题优化-->
<template
v-if="item.type == WorkflowType.AiChat || item.type == WorkflowType.Question"
v-if="
item.type == WorkflowType.AiChat ||
item.type == WorkflowType.Question ||
item.type == WorkflowType.Application
"
>
<div class="card-never border-r-4">
<div
class="card-never border-r-4"
v-if="item.type !== WorkflowType.Application"
>
<h5 class="p-8-12">角色设定 (System)</h5>
<div class="p-8-12 border-t-dashed lighter">
{{ item.system || '-' }}
</div>
</div>
<div class="card-never border-r-4 mt-8">
<div
class="card-never border-r-4 mt-8"
v-if="item.type !== WorkflowType.Application"
>
<h5 class="p-8-12">历史记录</h5>
<div class="p-8-12 border-t-dashed lighter">
<template v-if="item.history_message?.length > 0">
@@ -115,7 +134,9 @@
</div>
</div>
<div class="card-never border-r-4 mt-8">
<h5 class="p-8-12">AI 回答</h5>
<h5 class="p-8-12">
{{ item.type == WorkflowType.Application ? '应用回答' : 'AI 回答' }}
</h5>
<div class="p-8-12 border-t-dashed lighter">
<MdPreview
v-if="item.answer"
@@ -269,7 +290,7 @@ watch(dialogVisible, (bool) => {

const open = (data: any) => {
detail.value = cloneDeep(data)

console.log(detail.value)
dialogVisible.value = true
}
onBeforeUnmount(() => {
Loading
Loading