Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: 工作流版本管理 #1386

Merged
merged 6 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions apps/application/migrations/0018_workflowversion_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.15 on 2024-10-16 15:17

from django.db import migrations, models

sql = """
UPDATE "public".application_work_flow_version
SET "name" = TO_CHAR(create_time, 'YYYY-MM-DD HH24:MI:SS');
"""


class Migration(migrations.Migration):
dependencies = [
('application', '0017_application_tts_model_params_setting'),
]

operations = [
migrations.AddField(
model_name='workflowversion',
name='name',
field=models.CharField(default='', max_length=128, verbose_name='版本名称'),
),
migrations.RunSQL(sql),
migrations.AddField(
model_name='workflowversion',
name='publish_user_id',
field=models.UUIDField(default=None, null=True, verbose_name='发布者id'),
),
migrations.AddField(
model_name='workflowversion',
name='publish_user_name',
field=models.CharField(default='', max_length=128, verbose_name='发布者名称'),
),
]
3 changes: 3 additions & 0 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class Meta:
class WorkFlowVersion(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
name = models.CharField(verbose_name="版本名称", max_length=128, default="")
publish_user_id = models.UUIDField(verbose_name="发布者id", max_length=128, default=None, null=True)
publish_user_name = models.CharField(verbose_name="发布者名称", max_length=128, default="")
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)

class Meta:
Expand Down
10 changes: 9 additions & 1 deletion apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@date:2023/11/7 10:02
@desc:
"""
import datetime
import hashlib
import json
import os
Expand Down Expand Up @@ -50,6 +51,7 @@
from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from users.models import User

chat_cache = cache.caches['chat_cache']

Expand Down Expand Up @@ -684,6 +686,8 @@ def delete(self, with_valid=True):
def publish(self, instance, with_valid=True):
if with_valid:
self.is_valid()
user_id = self.data.get('user_id')
user = QuerySet(User).filter(id=user_id).first()
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
work_flow = instance.get('work_flow')
if work_flow is None:
Expand All @@ -703,7 +707,10 @@ def publish(self, instance, with_valid=True):
application.save()
# 插入知识库关联关系
self.save_application_mapping(application_dataset_id_list, dataset_id_list, application.id)
work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application)
work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application,
name=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
publish_user_id=user_id,
publish_user_name=user.username)
chat_cache.clear_by_application_id(str(application.id))
work_flow_version.save()
return True
Expand Down Expand Up @@ -1002,6 +1009,7 @@ def text_to_speech(self, text, with_valid=True):
if application.tts_model_enable:
model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id,
**application.tts_model_params_setting)

return model.text_to_speech(text)

def play_demo_text(self, form_data, with_valid=True):
Expand Down
84 changes: 84 additions & 0 deletions apps/application/serializers/application_version_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: application_version_serializers.py
@date:2024/10/15 16:42
@desc:
"""
from typing import Dict

from django.db.models import QuerySet
from rest_framework import serializers

from application.models import WorkFlowVersion
from common.db.search import page_search
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage


class ApplicationVersionModelSerializer(serializers.ModelSerializer):
class Meta:
model = WorkFlowVersion
fields = ['id', 'name', 'application_id', 'work_flow', 'publish_user_id', 'publish_user_name', 'create_time',
'update_time']


class ApplicationVersionEditSerializer(serializers.Serializer):
name = serializers.CharField(required=False, max_length=128, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("版本名称"))


class ApplicationVersionSerializer(serializers.Serializer):
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("应用id"))
name = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("摘要"))

def get_query_set(self):
query_set = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id'))
if 'name' in self.data and self.data.get('name') is not None:
query_set = query_set.filter(name__contains=self.data.get('name'))
return query_set.order_by("-create_time")

def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
query_set = self.get_query_set()
return [ApplicationVersionModelSerializer(v).data for v in query_set]

def page(self, current_page, page_size, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return page_search(current_page, page_size,
self.get_query_set(),
post_records_handler=lambda v: ApplicationVersionModelSerializer(v).data)

class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("应用id"))
work_flow_version_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("工作流版本id"))

def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id'),
id=self.data.get('work_flow_version_id')).first()
if work_flow_version is not None:
return ApplicationVersionModelSerializer(work_flow_version).data
else:
raise AppApiException(500, '不存在的工作流版本')

def edit(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ApplicationVersionEditSerializer(data=instance).is_valid(raise_exception=True)
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id'),
id=self.data.get('work_flow_version_id')).first()
if work_flow_version is not None:
name = instance.get('name', None)
if name is not None and len(name) > 0:
work_flow_version.name = name
work_flow_version.save()
return ApplicationVersionModelSerializer(work_flow_version).data
else:
raise AppApiException(500, '不存在的工作流版本')
69 changes: 69 additions & 0 deletions apps/application/swagger_api/application_version_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: application_version_api.py
@date:2024/10/15 17:18
@desc:
"""
from drf_yasg import openapi

from common.mixins.api_mixin import ApiMixin


class ApplicationVersionApi(ApiMixin):
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'work_flow', 'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_NUMBER, title="主键id",
description="主键id"),
'name': openapi.Schema(type=openapi.TYPE_NUMBER, title="版本名称",
description="版本名称"),
'work_flow': openapi.Schema(type=openapi.TYPE_STRING, title="工作流数据", description='工作流数据'),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间')
}
)

class Query(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='版本名称')]

class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='work_flow_version_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用版本id'), ]

class Edit(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=[],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="版本名称",
description="版本名称")
}
)
7 changes: 6 additions & 1 deletion apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@
name='application/audio'),
path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(),
name='application/audio'),
path('application/<str:application_id>/work_flow_version', views.ApplicationVersionView.as_view()),
path('application/<str:application_id>/work_flow_version/<int:current_page>/<int:page_size>',
views.ApplicationVersionView.Page.as_view()),
path('application/<str:application_id>/work_flow_version/<str:work_flow_version_id>',
views.ApplicationVersionView.Operate.as_view()),
path('application/<str:application_id>/play_demo_text', views.Application.PlayDemoText.as_view(),
name='application/audio'),
name='application/audio')

]
1 change: 1 addition & 0 deletions apps/application/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
"""
from .application_views import *
from .chat_views import *
from .application_version_views import *
89 changes: 89 additions & 0 deletions apps/application/views/application_version_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: application_version_views.py
@date:2024/10/15 16:49
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView

from application.serializers.application_version_serializers import ApplicationVersionSerializer
from application.swagger_api.application_version_api import ApplicationVersionApi
from common.auth import has_permissions, TokenAuth
from common.constants.permission_constants import PermissionConstants, CompareConstants, ViewPermission, RoleConstants, \
Permission, Group, Operate
from common.response import result


class ApplicationVersionView(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用列表",
operation_id="获取应用列表",
manual_parameters=ApplicationVersionApi.Query.get_request_params_api(),
responses=result.get_api_array_response(ApplicationVersionApi.get_response_body_api()),
tags=['应用/版本'])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
def get(self, request: Request, application_id: str):
return result.success(
ApplicationVersionSerializer.Query(
data={'name': request.query_params.get('name'), 'user_id': request.user.id,
'application_id': application_id}).list())

class Page(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="分页获取应用版本列表",
operation_id="分页获取应用版本列表",
manual_parameters=result.get_page_request_params(
ApplicationVersionApi.Query.get_request_params_api()),
responses=result.get_page_api_response(ApplicationVersionApi.get_response_body_api()),
tags=['应用/版本'])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
def get(self, request: Request, application_id: str, current_page: int, page_size: int):
return result.success(
ApplicationVersionSerializer.Query(
data={'name': request.query_params.get('name'), 'user_id': request.user,
'application_id': application_id}).page(
current_page, page_size))

class Operate(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用版本详情",
operation_id="获取应用版本详情",
manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(),
responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()),
tags=['应用/版本'])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
def get(self, request: Request, application_id: str, work_flow_version_id: str):
return result.success(
ApplicationVersionSerializer.Operate(
data={'user_id': request.user,
'application_id': application_id, 'work_flow_version_id': work_flow_version_id}).one())

@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="修改应用版本信息",
operation_id="修改应用版本信息",
manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(),
request_body=ApplicationVersionApi.Edit.get_request_body_api(),
responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()),
tags=['应用/版本'])
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str, work_flow_version_id: str):
return result.success(
ApplicationVersionSerializer.Operate(
data={'application_id': application_id, 'work_flow_version_id': work_flow_version_id,
'user_id': request.user.id}).edit(
request.data))
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from typing import Dict, List

from langchain_core.embeddings import Embeddings
Expand Down Expand Up @@ -34,3 +35,7 @@ def new_instance(model_type: str, model_name: str, model_credential: Dict[str, s
secret_key=model_credential.get('SecretKey'),
model_name=model_name,
)

def _generate_auth_token(self):
# Example method to generate an authentication token for the model API
return f"{self.secret_id}:{self.secret_key}"
Loading
Loading