Skip to content

Commit a827396

Browse files
committed
feat: Gemini Image understand model
1 parent 89d17ad commit a827396

File tree

3 files changed

+121
-10
lines changed

3 files changed

+121
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class GeminiImageModelParams(BaseForm):
14+
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
15+
required=True, default_value=0.7,
16+
_min=0.1,
17+
_max=1.0,
18+
_step=0.01,
19+
precision=2)
20+
21+
max_tokens = forms.SliderField(
22+
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
23+
required=True, default_value=800,
24+
_min=1,
25+
_max=100000,
26+
_step=1,
27+
precision=0)
28+
29+
30+
31+
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
32+
api_key = forms.PasswordInputField('API Key', required=True)
33+
34+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
35+
raise_exception=False):
36+
model_type_list = provider.get_model_type_list()
37+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
38+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
39+
40+
for key in ['api_key']:
41+
if key not in model_credential:
42+
if raise_exception:
43+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
44+
else:
45+
return False
46+
try:
47+
model = provider.get_model(model_type, model_name, model_credential)
48+
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
49+
for chunk in res:
50+
print(chunk)
51+
except Exception as e:
52+
if isinstance(e, AppApiException):
53+
raise e
54+
if raise_exception:
55+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
56+
else:
57+
return False
58+
return True
59+
60+
def encryption_dict(self, model: Dict[str, object]):
61+
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
62+
63+
def get_model_params_setting_form(self, model_name):
64+
return GeminiImageModelParams()

apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,47 @@
1111
from common.util.file_util import get_file_content
1212
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
1313
ModelInfoManage
14+
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
1415
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
16+
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
1517
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
1618
from smartdoc.conf import PROJECT_DIR
1719

1820
gemini_llm_model_credential = GeminiLLMModelCredential()
21+
gemini_image_model_credential = GeminiImageModelCredential()
1922

20-
gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
21-
ModelTypeConst.LLM,
22-
gemini_llm_model_credential,
23-
GeminiChatModel)
23+
model_info_list = [
24+
ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
25+
ModelTypeConst.LLM,
26+
gemini_llm_model_credential,
27+
GeminiChatModel),
28+
ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
29+
ModelTypeConst.LLM,
30+
gemini_llm_model_credential,
31+
GeminiChatModel),
32+
]
2433

25-
gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
26-
ModelTypeConst.LLM,
27-
gemini_llm_model_credential,
28-
GeminiChatModel)
34+
model_image_info_list = [
35+
ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
36+
ModelTypeConst.IMAGE,
37+
gemini_image_model_credential,
38+
GeminiImage),
39+
ModelInfo('gemini-1.5-pro', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
40+
ModelTypeConst.IMAGE,
41+
gemini_image_model_credential,
42+
GeminiImage),
43+
]
2944

30-
model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info(
31-
gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build()
45+
46+
47+
model_info_manage = (
48+
ModelInfoManage.builder()
49+
.append_model_info_list(model_info_list)
50+
.append_model_info_list(model_image_info_list)
51+
.append_default_model_info(model_info_list[0])
52+
.append_default_model_info(model_image_info_list[0])
53+
.build()
54+
)
3255

3356

3457
class GeminiModelProvider(IModelProvider):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Dict
2+
3+
from langchain_google_genai import ChatGoogleGenerativeAI
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
def custom_get_token_ids(text: str):
10+
tokenizer = TokenizerManage.get_tokenizer()
11+
return tokenizer.encode(text)
12+
13+
14+
class GeminiImage(MaxKBBaseModel, ChatGoogleGenerativeAI):
15+
16+
@staticmethod
17+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
18+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
19+
return GeminiImage(
20+
model=model_name,
21+
google_api_key=model_credential.get('api_key'),
22+
streaming=True,
23+
**optional_params,
24+
)

0 commit comments

Comments
 (0)