Skip to content

Commit f85ce4a

Browse files
committed
feat: 支持讯飞向量模型
1 parent 97cfd60 commit f85ce4a

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/10/17 15:40
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from common import forms
12+
from common.exception.app_exception import AppApiException
13+
from common.forms import BaseForm
14+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
15+
16+
17+
class XFEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the XunFei (iFlytek) embedding model.

    The form fields declared at the bottom of the class describe the values
    collected from the user: the API base URL, the APP ID, the API key and
    the API secret.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credentials by issuing a real embedding request.

        :param model_type: model type value that must be listed by ``provider``.
        :param model_name: name of the model to instantiate.
        :param model_credential: credential values entered in the form.
        :param provider: model provider; supplies ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when True, a failed probe raises instead of returning False.
        :return: True when the probe request succeeds, False when it fails and
                 ``raise_exception`` is falsy.
        :raises AppApiException: when the model type is not supported, when the
                 probe raises an ``AppApiException`` itself, or when the probe
                 fails and ``raise_exception`` is truthy.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        self.valid_form(model_credential)
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe the service with a short query to prove the credentials work.
            model.embed_query('你好')
        except AppApiException:
            # Application-level errors are always propagated unchanged.
            raise
        except Exception as e:
            if raise_exception:
                # Chain the original error so the root cause stays in the traceback.
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return the credential dict as received, without modification.

        NOTE(review): spark_api_key / spark_api_secret are returned unmasked
        here - confirm this is intended.
        """
        return model

    # Form fields shown to the user when configuring the model.
    base_url = forms.TextInputField('API 域名', required=True, default_value="https://emb-cn-huabei-1.xf-yun.com/")
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/10/17 15:29
7+
@desc:
8+
"""
9+
10+
import base64
11+
import json
12+
from typing import Dict, Optional
13+
14+
import numpy as np
15+
from langchain_community.embeddings import SparkLLMTextEmbeddings
16+
from numpy import ndarray
17+
18+
from setting.models_provider.base_model_provider import MaxKBBaseModel
19+
20+
21+
class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
    """XunFei (iFlytek) Spark text-embedding model built on ``SparkLLMTextEmbeddings``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``XFEmbedding`` from the stored credential values.

        :param model_type: unused here; part of the common factory signature.
        :param model_name: unused here; part of the common factory signature.
        :param model_credential: dict holding ``spark_app_id``, ``spark_api_key``,
            ``spark_api_secret`` and, optionally, ``base_url``.
        :param model_kwargs: extra options; not used by this factory.
        :return: a configured ``XFEmbedding`` instance.
        """
        init_kwargs = {
            'spark_app_id': model_credential.get('spark_app_id'),
            'spark_api_key': model_credential.get('spark_api_key'),
            'spark_api_secret': model_credential.get('spark_api_secret'),
        }
        # Forward the user-configured API domain. It is omitted when missing or
        # empty so the base class keeps its own default endpoint.
        base_url = model_credential.get('base_url')
        if base_url:
            init_kwargs['base_url'] = base_url
        return XFEmbedding(**init_kwargs)

    @staticmethod
    def _parser_message(
            message: str,
    ) -> Optional[ndarray]:
        """Decode one JSON response message into a float32 vector.

        :param message: JSON text containing ``header.code`` and, on success,
            a base64 string at ``payload.feature.text``.
        :return: the decoded little-endian float32 array, limited to at most
            2560 elements.
        :raises Exception: when ``header.code`` is non-zero.
        """
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            # iFlytek's QPS limit triggers errors here, so using iFlytek's
            # embedding model is not recommended.
            raise Exception(f"Request error: {code}, {data}")
        text_data = base64.b64decode(data["payload"]["feature"]["text"])
        # The payload is decoded as little-endian float32 values.
        dt = np.dtype(np.float32).newbyteorder("<")
        # Slicing caps the vector at 2560 elements and is a no-op for shorter ones.
        return np.frombuffer(text_data, dtype=dt)[:2560]

apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from common.util.file_util import get_file_content
1313
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
1414
ModelInfoManage
15+
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
1516
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
1617
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
1718
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
19+
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
1820
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
1921
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
2022
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
@@ -25,12 +27,14 @@
2527
# One shared credential-form instance per model type.
qwen_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()

# The LLM entries differ only by model name, so they are generated from a
# tuple; the speech and embedding entries are appended in their original order.
model_info_list = [
    ModelInfo(llm_name, '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
    for llm_name in ('generalv3.5', 'generalv3', 'generalv2')
] + [
    ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
    ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
    ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding),
]
3539

3640
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(

0 commit comments

Comments
 (0)