From 4b48e490fab07adae07b9fad63afcab819864e4f Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 5 Nov 2024 17:11:33 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0Mistral=E6=B8=A0?= =?UTF-8?q?=E9=81=93=20(close=20#546)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 + relay/channel/mistral/adaptor.go | 72 ++++++++++++++++++++++++++ relay/channel/mistral/constants.go | 12 +++++ relay/channel/mistral/text.go | 40 ++++++++++++++ relay/constant/api_type.go | 3 ++ relay/relay_adaptor.go | 3 ++ web/src/constants/channel.constants.js | 1 + 7 files changed, 133 insertions(+) create mode 100644 relay/channel/mistral/adaptor.go create mode 100644 relay/channel/mistral/constants.go create mode 100644 relay/channel/mistral/text.go diff --git a/common/constants.go b/common/constants.go index 1f4d3f806..d724691c3 100644 --- a/common/constants.go +++ b/common/constants.go @@ -222,6 +222,7 @@ const ( ChannelCloudflare = 39 ChannelTypeSiliconFlow = 40 ChannelTypeVertexAi = 41 + ChannelTypeMistral = 42 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -270,4 +271,5 @@ var ChannelBaseURLs = []string{ "https://api.cloudflare.com", //39 "https://api.siliconflow.cn", //40 "", //41 + "https://api.mistral.ai", //42 } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go new file mode 100644 index 000000000..5ca095d81 --- /dev/null +++ b/relay/channel/mistral/adaptor.go @@ -0,0 +1,72 @@ +package mistral + +import ( + "errors" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + mistralReq := requestOpenAI2Mistral(*request) + //common.LogJson(c, "body", mistralReq) + return mistralReq, nil +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/mistral/constants.go b/relay/channel/mistral/constants.go new file mode 100644 index 000000000..7f5f3acac --- /dev/null +++ b/relay/channel/mistral/constants.go @@ -0,0 +1,12 @@ +package mistral + +var ModelList = []string{ + "open-mistral-7b", + "open-mixtral-8x7b", + "mistral-small-latest", + "mistral-medium-latest", + "mistral-large-latest", + "mistral-embed", +} + +var ChannelName = "mistral" diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go new file mode 100644 index 000000000..04add0675 --- /dev/null +++ b/relay/channel/mistral/text.go @@ -0,0 +1,40 @@ +package mistral + +import ( + "encoding/json" + "one-api/dto" +) + +func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { + messages := make([]dto.Message, 0, len(request.Messages)) + for _, message := range request.Messages { + if !message.IsStringContent() { + mediaMessages := message.ParseContent() + for j, mediaMessage := range mediaMessages { + if mediaMessage.Type == dto.ContentTypeImageURL { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + mediaMessage.ImageUrl = imageUrl.Url + mediaMessages[j] = mediaMessage + } + } + messageRaw, _ := json.Marshal(mediaMessages) + message.Content = messageRaw + } + messages = append(messages, dto.Message{ + Role: message.Role, + Content: message.Content, + ToolCalls: message.ToolCalls, + ToolCallId: message.ToolCallId, + }) + } + return &dto.GeneralOpenAIRequest{ + Model: request.Model, + Stream: request.Stream, + Messages: messages, + Temperature: request.Temperature, + TopP: request.TopP, + MaxTokens: request.MaxTokens, + Tools: request.Tools, + ToolChoice: request.ToolChoice, + } +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 98c6cc0e8..4be095257 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -25,6 +25,7 @@ const ( APITypeCloudflare APITypeSiliconFlow APITypeVertexAi + APITypeMistral APITypeDummy // this one is only for count, do not add any channel after this ) @@ -72,6 +73,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeSiliconFlow case common.ChannelTypeVertexAi: apiType = APITypeVertexAi + case common.ChannelTypeMistral: + apiType = APITypeMistral } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 01647825d..5219c8f31 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -12,6 +12,7 @@ import ( "one-api/relay/channel/dify" "one-api/relay/channel/gemini" "one-api/relay/channel/jina" + "one-api/relay/channel/mistral" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" "one-api/relay/channel/palm" @@ -68,6 +69,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &siliconflow.Adaptor{} case constant.APITypeVertexAi: return &vertex.Adaptor{} + case constant.APITypeMistral: + return &mistral.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 8a3ddbd39..711edb684 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -109,6 +109,7 @@ export const CHANNEL_OPTIONS = [ { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' }, { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' }, { key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' }, + { key: 42, text: 'Mistral AI', value: 42, color: 'blue', label: 'Mistral AI' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' }, { key: 22,