Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 大语言模型支持自定义参数入参 #1458

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
def is_cache_model():
return True

@staticmethod
def filter_optional_params(model_kwargs):
optional_params = {}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params[key] = value
return optional_params


class BaseModelCredential(ABC):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ def __init__(self, model_id: str, region_name: str, credentials_profile_name: st
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
keyword = get_max_tokens_keyword(model_name)
optional_params[keyword] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

return cls(
model_id=model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

return AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
if 'max_tokens' in model_kwargs:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

gemini_chat = GeminiChatModel(
model=model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
azure_chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
chat_tong_yi = QwenChatModel(
model_name=model_name,
dashscope_api_key=model_credential.get('api_key'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool
hunyuan_secret_id = credentials.get('hunyuan_secret_id')
hunyuan_secret_key = credentials.get('hunyuan_secret_key')

optional_params = {}
if 'temperature' in kwargs:
optional_params['temperature'] = kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(kwargs)

if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
vllm_chat_open_ai = VllmChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return VolcanicEngineChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
Expand Down
11 changes: 3 additions & 8 deletions apps/setting/models_provider/impl/xf_model_provider/model/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
@date:2024/04/19 15:55
@desc:
"""
import json
from typing import List, Optional, Any, Iterator, Dict

from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
ChatSparkLLM
from langchain_community.chat_models.sparkllm import \
ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
Expand All @@ -25,11 +24,7 @@ def is_cache_model():

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# coding=utf-8

from typing import Dict
from typing import Dict, Optional, List, Any, Iterator
from urllib.parse import urlparse, ParseResult

from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessageChunk
from langchain_core.runnables import RunnableConfig

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

Expand All @@ -26,11 +30,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XinferenceChatModel(
model=model_name,
openai_api_base=base_url,
Expand Down
24 changes: 11 additions & 13 deletions apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,41 @@
@desc:
"""

from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
import json
from collections.abc import Iterator
from typing import Any, Dict, List, Optional

from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)

from langchain_core.messages import (
AIMessageChunk,
BaseMessage
)
from langchain_core.outputs import ChatGenerationChunk

from setting.models_provider.base_model_provider import MaxKBBaseModel


class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
optional_params: dict

@staticmethod
def is_cache_model():
return False

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']

optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
zhipuai_chat = ZhipuChatModel(
api_key=model_credential.get('api_key'),
model=model_name,
streaming=model_kwargs.get('streaming', False),
**optional_params
optional_params=optional_params,
**optional_params,
)
return zhipuai_chat

Expand Down Expand Up @@ -71,7 +69,7 @@ def _stream(
if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
Expand Down
Loading