|
| 1 | +# coding=utf-8 |
| 2 | +import base64 |
| 3 | +import os |
| 4 | +from typing import Dict |
| 5 | + |
| 6 | +from langchain_core.messages import HumanMessage |
| 7 | + |
| 8 | +from common import forms |
| 9 | +from common.exception.app_exception import AppApiException |
| 10 | +from common.forms import BaseForm, TooltipLabel |
| 11 | +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode |
| 12 | + |
| 13 | +class GeminiImageModelParams(BaseForm): |
| 14 | + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), |
| 15 | + required=True, default_value=0.7, |
| 16 | + _min=0.1, |
| 17 | + _max=1.0, |
| 18 | + _step=0.01, |
| 19 | + precision=2) |
| 20 | + |
| 21 | + max_tokens = forms.SliderField( |
| 22 | + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), |
| 23 | + required=True, default_value=800, |
| 24 | + _min=1, |
| 25 | + _max=100000, |
| 26 | + _step=1, |
| 27 | + precision=0) |
| 28 | + |
| 29 | + |
| 30 | + |
| 31 | +class GeminiImageModelCredential(BaseForm, BaseModelCredential): |
| 32 | + api_key = forms.PasswordInputField('API Key', required=True) |
| 33 | + |
| 34 | + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, |
| 35 | + raise_exception=False): |
| 36 | + model_type_list = provider.get_model_type_list() |
| 37 | + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): |
| 38 | + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') |
| 39 | + |
| 40 | + for key in ['api_key']: |
| 41 | + if key not in model_credential: |
| 42 | + if raise_exception: |
| 43 | + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') |
| 44 | + else: |
| 45 | + return False |
| 46 | + try: |
| 47 | + model = provider.get_model(model_type, model_name, model_credential) |
| 48 | + res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) |
| 49 | + for chunk in res: |
| 50 | + print(chunk) |
| 51 | + except Exception as e: |
| 52 | + if isinstance(e, AppApiException): |
| 53 | + raise e |
| 54 | + if raise_exception: |
| 55 | + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') |
| 56 | + else: |
| 57 | + return False |
| 58 | + return True |
| 59 | + |
| 60 | + def encryption_dict(self, model: Dict[str, object]): |
| 61 | + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} |
| 62 | + |
| 63 | + def get_model_params_setting_form(self, model_name): |
| 64 | + return GeminiImageModelParams() |
0 commit comments