-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
feat: BaiLian Image Model #1844
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: MaxKB | ||
@Author:虎 | ||
@file: llm.py | ||
@date:2024/7/11 18:41 | ||
@desc: | ||
""" | ||
import base64 | ||
import os | ||
from typing import Dict | ||
|
||
from langchain_core.messages import HumanMessage | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm, TooltipLabel | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class QwenModelParams(BaseForm): | ||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), | ||
required=True, default_value=1.0, | ||
_min=0.1, | ||
_max=1.9, | ||
_step=0.01, | ||
precision=2) | ||
|
||
max_tokens = forms.SliderField( | ||
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), | ||
required=True, default_value=800, | ||
_min=1, | ||
_max=100000, | ||
_step=1, | ||
precision=0) | ||
|
||
|
||
class QwenVLModelCredential(BaseForm, BaseModelCredential): | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
for key in ['api_key']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) | ||
for chunk in res: | ||
print(chunk) | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} | ||
|
||
api_key = forms.PasswordInputField('API Key', required=True) | ||
|
||
def get_model_params_setting_form(self, model_name): | ||
return QwenModelParams() | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: MaxKB | ||
@Author:虎 | ||
@file: llm.py | ||
@date:2024/7/11 18:41 | ||
@desc: | ||
""" | ||
import base64 | ||
import os | ||
from typing import Dict | ||
|
||
from langchain_core.messages import HumanMessage | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm, TooltipLabel | ||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
||
|
||
class QwenModelParams(BaseForm): | ||
size = forms.SingleSelect( | ||
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), | ||
required=True, | ||
default_value='1024*1024', | ||
option_list=[ | ||
{'value': '1024*1024', 'label': '1024*1024'}, | ||
{'value': '720*1280', 'label': '720*1280'}, | ||
{'value': '768*1152', 'label': '768*1152'}, | ||
{'value': '1280*720', 'label': '1280*720'}, | ||
], | ||
text_field='label', | ||
value_field='value') | ||
n = forms.SliderField( | ||
TooltipLabel('图片数量', '指定生成图片的数量'), | ||
required=True, default_value=1, | ||
_min=1, | ||
_max=4, | ||
_step=1, | ||
precision=0) | ||
style = forms.SingleSelect( | ||
TooltipLabel('风格', '指定生成图片的风格'), | ||
required=True, | ||
default_value='<auto>', | ||
option_list=[ | ||
{'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'}, | ||
{'value': '<photography>', 'label': '摄影'}, | ||
{'value': '<portrait>', 'label': '人像写真'}, | ||
{'value': '<3d cartoon>', 'label': '3D卡通'}, | ||
{'value': '<anime>', 'label': '动画'}, | ||
{'value': '<oil painting>', 'label': '油画'}, | ||
{'value': '<watercolor>', 'label': '水彩'}, | ||
{'value': '<sketch>', 'label': '素描'}, | ||
{'value': '<chinese painting>', 'label': '中国画'}, | ||
{'value': '<flat illustration>', 'label': '扁平插画'}, | ||
], | ||
text_field='label', | ||
value_field='value' | ||
) | ||
|
||
|
||
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, | ||
raise_exception=False): | ||
model_type_list = provider.get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
for key in ['api_key']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = provider.get_model(model_type, model_name, model_credential) | ||
res = model.check_auth() | ||
print(res) | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} | ||
|
||
api_key = forms.PasswordInputField('API Key', required=True) | ||
|
||
def get_model_params_setting_form(self, model_name): | ||
return QwenModelParams() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 代码中存在以下不规范和潜在问题:
优化建议:
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# coding=utf-8 | ||
|
||
from typing import Dict | ||
|
||
from langchain_community.chat_models import ChatOpenAI | ||
|
||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
|
||
|
||
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI): | ||
|
||
@staticmethod | ||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) | ||
chat_tong_yi = QwenVLChatModel( | ||
model_name=model_name, | ||
openai_api_key=model_credential.get('api_key'), | ||
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', | ||
# stream_options={"include_usage": True}, | ||
streaming=True, | ||
model_kwargs=optional_params, | ||
) | ||
return chat_tong_yi |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# coding=utf-8 | ||
from http import HTTPStatus | ||
from typing import Dict | ||
|
||
from dashscope import ImageSynthesis | ||
from langchain_community.chat_models import ChatTongyi | ||
from langchain_core.messages import HumanMessage | ||
|
||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
from setting.models_provider.impl.base_tti import BaseTextToImage | ||
|
||
|
||
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage): | ||
api_key: str | ||
model_name: str | ||
params: dict | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.api_key = kwargs.get('api_key') | ||
self.model_name = kwargs.get('model_name') | ||
self.params = kwargs.get('params') | ||
|
||
@staticmethod | ||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): | ||
optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}} | ||
for key, value in model_kwargs.items(): | ||
if key not in ['model_id', 'use_local', 'streaming']: | ||
optional_params['params'][key] = value | ||
chat_tong_yi = QwenTextToImageModel( | ||
model_name=model_name, | ||
api_key=model_credential.get('api_key'), | ||
**optional_params, | ||
) | ||
return chat_tong_yi | ||
|
||
def is_cache_model(self): | ||
return False | ||
|
||
def check_auth(self): | ||
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') | ||
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])]) | ||
|
||
def generate_image(self, prompt: str, negative_prompt: str = None): | ||
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', | ||
rsp = ImageSynthesis.call(api_key=self.api_key, | ||
model=self.model_name, | ||
prompt=prompt, | ||
negative_prompt=negative_prompt, | ||
**self.params) | ||
file_urls = [] | ||
if rsp.status_code == HTTPStatus.OK: | ||
for result in rsp.output.results: | ||
file_urls.append(result.url) | ||
else: | ||
print('sync_call Failed, status_code: %s, code: %s, message: %s' % | ||
(rsp.status_code, rsp.code, rsp.message)) | ||
return file_urls | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看起来这段代码是一个从阿里云迁移到DashScope的文本到图像模型实现。以下是存在的一些问题和改进建议:
@@ -0,0 +1,58 @@
+# coding=utf-8
+from http import HTTPStatus
+from typing import Dict
+
+from dashscope.dashscope_client import DashScopeClient # 修正缺失导入
+from dashscope import ImageSynthesis
+from langchain_community.chat_models import ChatTongyi
+from langchain.core.messages import HumanMessage
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tti import BaseTextToImage
+
+
+class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
+ api_key: str
+ model_name: str
+ params: dict
以上是主要的问题和改进建议,可以根据实际情况进一步调整其他部分。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码整体结构良好,但存在一些需要关注的问题和优化点:
导入顺序问题:
应该先导入
HumanMessage
类。API 错误处理:
在捕获到其他异常时,建议只打印日志或记录错误信息,而不是重新抛出异常,除非有明确的需求。
加密逻辑简化:
如果
BaseModelCredential
类已经实现了一个安全的加密方法,并且这个加密逻辑适用于QwenVLModelCredential
,可以直接调用而不需要重复编写加密。模型测试功能:
需要添加调试信息以确保模型流获取成功,或者减少打印过多的日志以便于后续维护。
类成员顺序:
推荐遵循驼峰命名法(PascalCase)来定义类成员名称,例如使用
apiKey
而不是api_key
。注释文档:
通过上述改进,可以提高代码质量,避免潜在的安全风险和性能问题。