Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: BaiLian Image Model #1844

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
ModelInfoManage
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \
AliyunBaiLianEmbeddingCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.image import QwenVLModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.llm import BaiLianLLMModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \
AliyunBaiLianRerankerCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tti import QwenTextToImageModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.image import QwenVLChatModel
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.llm import BaiLianChatModel
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
from smartdoc.conf import PROJECT_DIR

Expand All @@ -30,6 +34,8 @@
# Module-level credential form instances. qwenvl_model_credential and qwentti_model_credential are
# the ones passed to the ModelInfo entries for the IMAGE and TTI model types further down.
aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential()
aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential()
aliyun_bai_lian_llm_model_credential = BaiLianLLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
qwentti_model_credential = QwenTextToImageModelCredential()

model_info_list = [ModelInfo('gte-rerank',
'阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。',
Expand All @@ -52,9 +58,28 @@
BaiLianChatModel)
]

# Registers model_info_list with the builder and passes entries 1-4 to append_default_model_info.
# NOTE(review): the same name is assigned again by the parenthesized builder further down;
# presumably this chained form is the pre-change side of the diff -- confirm before relying on it.
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
model_info_list[1]).append_default_model_info(model_info_list[2]).append_default_model_info(
model_info_list[3]).append_default_model_info(model_info_list[4]).build()
# Models registered with ModelTypeConst.IMAGE. All three share qwenvl_model_credential and
# QwenVLChatModel, and are declared with an empty description string.
module_info_vl_list = [
ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
]
# Models registered with ModelTypeConst.TTI, using qwentti_model_credential and QwenTextToImageModel.
module_info_tti_list = [
ModelInfo('wanx-v1',
'通义万相-文本生成图像大模型,支持中英文双语输入,支持输入参考图片进行参考内容或者参考风格迁移,重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。',
ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
]

# Builds the provider's ModelInfoManage: appends the three model lists, then passes
# entries 1-4 of model_info_list to append_default_model_info.
# NOTE(review): no entry from module_info_vl_list or module_info_tti_list is passed to
# append_default_model_info -- confirm whether IMAGE/TTI are meant to have no default.
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_model_info_list(module_info_vl_list)
.append_model_info_list(module_info_tti_list)
.append_default_model_info(model_info_list[1])
.append_default_model_info(model_info_list[2])
.append_default_model_info(model_info_list[3])
.append_default_model_info(model_info_list[4])
.build()
)


class AliyunBaiLianModelProvider(IModelProvider):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 18:41
@desc:
"""
import base64
import os
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class QwenModelParams(BaseForm):
    """Parameter form returned by QwenVLModelCredential.get_model_params_setting_form.

    Declares two required slider fields: ``temperature`` and ``max_tokens``.
    """

    # Slider over [0.1, 1.9] in steps of 0.01; defaults to 1.0.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=1.0,
        _min=0.1,
        _max=1.9,
        _step=0.01,
        precision=2,
    )

    # Integer slider over [1, 100000]; defaults to 800.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0,
    )


class QwenVLModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the Qwen-VL image models.

    The form exposes a single required field, ``api_key``.
    """

    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by streaming a short test prompt through the model.

        Args:
            model_type: Model type value; must match a ``value`` in
                ``provider.get_model_type_list()``.
            model_name: Name of the model to instantiate for the probe.
            model_credential: Credential mapping; must contain ``api_key``.
            provider: Provider object supplying ``get_model_type_list`` and ``get_model``.
            raise_exception: When true, failures are raised as ``AppApiException``
                instead of being reported as ``False``.

        Returns:
            ``True`` when the probe request completes, ``False`` on failure when
            ``raise_exception`` is false.

        Raises:
            AppApiException: If the model type is unsupported (always raised), or,
                when ``raise_exception`` is true, if a required key is missing or
                the probe request fails.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Drain the stream so request/authentication errors surface here; the chunks
            # themselves are not needed (and must not be dumped to stdout).
            for _ in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                pass
        except AppApiException:
            # Already in the application's error format; propagate unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` replaced by its ``encryption()`` result."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (``QwenModelParams``) for the given model name."""
        return QwenModelParams()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码整体结构良好,但存在一些需要关注的问题和优化点:

  1. 导入顺序问题

    from langchain_core.messages import HumanMessage

    应该先导入 HumanMessage 类。

  2. API 错误处理

    except Exception as e:
        if isinstance(e, AppApiException):
            raise e
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        else:
            return False

    在捕获到其他异常时,建议只打印日志或记录错误信息,而不是重新抛出异常,除非有明确的需求。

  3. 加密逻辑简化

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    如果 BaseModelCredential 类已经实现了一个安全的加密方法,并且这个加密逻辑适用于 QwenVLModelCredential,可以直接调用而不需要重复编写加密。

  4. 模型测试功能

    try:
        model = provider.get_model(model_type, model_name, model_credential)
        res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
        for chunk in res:
            print(chunk)
    except Exception as e:
        if isinstance(e, AppApiException):
            raise e
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        else:
            return False

    需要添加调试信息以确保模型流获取成功,或者减少打印过多的日志以便于后续维护。

  5. 类成员顺序

    class QwenVLModelCredential(BaseForm, BaseModelCredential):
        api_key = forms.PasswordInputField('API Key', required=True)
        # 其他成员...

    推荐遵循驼峰命名法(camelCase)来定义类成员名称,例如使用 apiKey 而不是 api_key。

  6. 注释文档

    """
    @project: MaxKB
    @Author:虎
    @file: llm.py
    @date:2024/7/11 18:41
    @desc:
    """

通过上述改进,可以提高代码质量,避免潜在的安全风险和性能问题。

Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 18:41
@desc:
"""
import base64
import os
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class QwenModelParams(BaseForm):
    """Parameter form returned by QwenTextToImageModelCredential.get_model_params_setting_form.

    Declares three required fields: ``size`` (select), ``n`` (slider) and
    ``style`` (select).
    """

    # Output resolution. The tooltip example uses the same "W*H" notation as the
    # option values so users are not shown a format ("1024x1024") that is not offered.
    size = forms.SingleSelect(
        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024*1024'),
        required=True,
        default_value='1024*1024',
        option_list=[
            {'value': '1024*1024', 'label': '1024*1024'},
            {'value': '720*1280', 'label': '720*1280'},
            {'value': '768*1152', 'label': '768*1152'},
            {'value': '1280*720', 'label': '1280*720'},
        ],
        text_field='label',
        value_field='value',
    )

    # Number of images per request: integer slider over [1, 4], default 1.
    n = forms.SliderField(
        TooltipLabel('图片数量', '指定生成图片的数量'),
        required=True,
        default_value=1,
        _min=1,
        _max=4,
        _step=1,
        precision=0,
    )

    # Style tag; '<auto>' is the default option.
    style = forms.SingleSelect(
        TooltipLabel('风格', '指定生成图片的风格'),
        required=True,
        default_value='<auto>',
        option_list=[
            {'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'},
            {'value': '<photography>', 'label': '摄影'},
            {'value': '<portrait>', 'label': '人像写真'},
            {'value': '<3d cartoon>', 'label': '3D卡通'},
            {'value': '<anime>', 'label': '动画'},
            {'value': '<oil painting>', 'label': '油画'},
            {'value': '<watercolor>', 'label': '水彩'},
            {'value': '<sketch>', 'label': '素描'},
            {'value': '<chinese painting>', 'label': '中国画'},
            {'value': '<flat illustration>', 'label': '扁平插画'},
        ],
        text_field='label',
        value_field='value',
    )


class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the text-to-image models.

    The form exposes a single required field, ``api_key``.
    """

    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by calling the model's ``check_auth()``.

        Args:
            model_type: Model type value; must match a ``value`` in
                ``provider.get_model_type_list()``.
            model_name: Name of the model to instantiate for the check.
            model_credential: Credential mapping; must contain ``api_key``.
            provider: Provider object supplying ``get_model_type_list`` and ``get_model``.
            raise_exception: When true, failures are raised as ``AppApiException``
                instead of being reported as ``False``.

        Returns:
            ``True`` when ``check_auth()`` completes, ``False`` on failure when
            ``raise_exception`` is false.

        Raises:
            AppApiException: If the model type is unsupported (always raised), or,
                when ``raise_exception`` is true, if a required key is missing or
                the auth check fails.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Only success/failure matters here; the result is not printed.
            model.check_auth()
        except AppApiException:
            # Already in the application's error format; propagate unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` replaced by its ``encryption()`` result."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (``QwenModelParams``) for the given model name."""
        return QwenModelParams()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码中存在以下不规范和潜在问题:

  1. from typing import Dict 应该是 from typing import * 来涵盖所有内置类型的导入。

  2. from common.forms import BaseForm, TooltipLabel 中的 BaseForm 和 TooltipLabel 可能会与其他模块中的同名类冲突,或存在重命名需求。

  3. QwenModelParams 类中的 _min、_max 和 _step 属性在 Python 3 中不需要使用下划线前缀。这些属性用于定义滑块组件的最小、最大和步长范围。如果需要支持旧版本,请保持不变;如需更新,可以修改为无前导下划线的形式。

  4. QwenTextToImageModelCredential 类中,方法参数顺序应与实际使用的顺序一致(即:raise_exception 参数应该放在最前面)。

  5. encryption_dict 方法中,对于 API 密钥(api_key)进行加密,并返回完整的解密后的模型字典(包含其他键值对)。这会导致任何依赖原始未加密 API 密钥的功能失效。建议只在其必要情况下进行加密操作。

优化建议:

  • 尽量减少不必要的字段,默认值设置得更为合理。
  • 异常处理逻辑中,当验证失败时返回具体的错误信息以帮助用户调试。
  • 确保 get_model_params_setting_form 方法返回的是正确的表单对象实例。

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# coding=utf-8

from typing import Dict

from langchain_community.chat_models import ChatOpenAI

from setting.models_provider.base_model_provider import MaxKBBaseModel


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    """Chat model for Qwen-VL, built on ChatOpenAI and pointed at DashScope's compatible-mode URL."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a streaming QwenVLChatModel.

        ``model_kwargs`` is passed through ``MaxKBBaseModel.filter_optional_params``
        and the result is forwarded as the model's ``model_kwargs``; the API key is
        read from ``model_credential['api_key']``. ``model_type`` is not used here.
        """
        extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # NOTE: stream_options={"include_usage": True} is currently not passed.
        return QwenVLChatModel(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
            streaming=True,
            model_kwargs=extra_params,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
from http import HTTPStatus
from typing import Dict

from dashscope import ImageSynthesis
from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import HumanMessage

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage


class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model that generates images through ``dashscope.ImageSynthesis``."""

    api_key: str
    model_name: str
    params: dict

    def __init__(self, **kwargs):
        """Initialise the base classes, then store ``api_key``, ``model_name`` and ``params``."""
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model_name = kwargs.get('model_name')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance with default generation params overridden by ``model_kwargs``.

        The keys ``model_id``, ``use_local`` and ``streaming`` are not copied into
        the generation params. ``model_type`` is not used here.
        """
        skipped_keys = ('model_id', 'use_local', 'streaming')
        generation_params = {'size': '1024*1024', 'style': '<auto>', 'n': 1}
        generation_params.update(
            {name: setting for name, setting in model_kwargs.items() if name not in skipped_keys}
        )
        return QwenTextToImageModel(
            model_name=model_name,
            api_key=model_credential.get('api_key'),
            params=generation_params,
        )

    def is_cache_model(self):
        """Always return False."""
        return False

    def check_auth(self):
        """Send a short chat request to 'qwen-max' with this API key; errors propagate to the caller."""
        probe = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
        probe.invoke([HumanMessage([{"type": "text", "text": "你好"}])])

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for ``prompt`` and return the list of result URLs.

        On a non-OK status the failure details are printed and an empty list is
        returned.
        """
        # NOTE: api_base='https://dashscope.aliyuncs.com/compatible-mode/v1' is not passed here.
        rsp = ImageSynthesis.call(
            api_key=self.api_key,
            model=self.model_name,
            prompt=prompt,
            negative_prompt=negative_prompt,
            **self.params,
        )
        if rsp.status_code != HTTPStatus.OK:
            print('sync_call Failed, status_code: %s, code: %s, message: %s' %
                  (rsp.status_code, rsp.code, rsp.message))
            return []
        return [item.url for item in rsp.output.results]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看起来这段代码是一个从阿里云迁移到DashScope的文本到图像模型实现。以下是存在的一些问题和改进建议:

  1. 缺少必要的导入: 还需要dashscope.ImageSynthesis模块和HTTPStatus类的导入。
  2. 不必要的方法调用: check_auth方法中没有实际的操作来验证身份认证,且不需要每次都创建一个全新的聊天对象。
  3. 参数默认值: 参数在构造函数初始化时已经覆盖了静态方法中的默认参数,这可能不是有意为之的。
@@ -0,0 +1,58 @@
+# coding=utf-8
+from http import HTTPStatus
+from typing import Dict
+
+from dashscope.dashscope_client import DashScopeClient  # 修正缺失导入
+from dashscope import ImageSynthesis
+from langchain_community.chat_models import ChatTongyi
+from langchain_core.messages import HumanMessage
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tti import BaseTextToImage
+
+
+class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
+    api_key: str
+    model_name: str
+    params: dict
    

以上是主要的问题和改进建议,可以根据实际情况进一步调整其他部分。

Loading