Skip to content

Commit

Permalink
feat: 添加Mistral渠道 (close #546)
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Nov 5, 2024
1 parent 3e2ae29 commit 4b48e49
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 0 deletions.
2 changes: 2 additions & 0 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ const (
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42

ChannelTypeDummy // this one is only for count, do not add any channel after this

Expand Down Expand Up @@ -270,4 +271,5 @@ var ChannelBaseURLs = []string{
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
}
72 changes: 72 additions & 0 deletions relay/channel/mistral/adaptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mistral

import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)

type Adaptor struct {
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
mistralReq := requestOpenAI2Mistral(*request)
//common.LogJson(c, "body", mistralReq)
return mistralReq, nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}

func (a *Adaptor) GetModelList() []string {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return ChannelName
}
12 changes: 12 additions & 0 deletions relay/channel/mistral/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package mistral

var ModelList = []string{
"open-mistral-7b",
"open-mixtral-8x7b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
"mistral-embed",
}

var ChannelName = "mistral"
40 changes: 40 additions & 0 deletions relay/channel/mistral/text.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package mistral

import (
"encoding/json"
"one-api/dto"
)

func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
}
3 changes: 3 additions & 0 deletions relay/constant/api_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral

APITypeDummy // this one is only for count, do not add any channel after this
)
Expand Down Expand Up @@ -72,6 +73,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeSiliconFlow
case common.ChannelTypeVertexAi:
apiType = APITypeVertexAi
case common.ChannelTypeMistral:
apiType = APITypeMistral
}
if apiType == -1 {
return APITypeOpenAI, false
Expand Down
3 changes: 3 additions & 0 deletions relay/relay_adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
Expand Down Expand Up @@ -68,6 +69,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &siliconflow.Adaptor{}
case constant.APITypeVertexAi:
return &vertex.Adaptor{}
case constant.APITypeMistral:
return &mistral.Adaptor{}
}
return nil
}
Expand Down
1 change: 1 addition & 0 deletions web/src/constants/channel.constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ export const CHANNEL_OPTIONS = [
{ key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
{ key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
{ key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' },
{ key: 42, text: 'Mistral AI', value: 42, color: 'blue', label: 'Mistral AI' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{
key: 22,
Expand Down

0 comments on commit 4b48e49

Please # to comment.