Skip to content

Commit

Permalink
Ai proxy support doubao (#1337)
Browse files Browse the repository at this point in the history
  • Loading branch information
rinfx authored Sep 24, 2024
1 parent dc61bfc commit bef9139
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
13 changes: 13 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,19 @@ provider:
}
```

### 使用 OpenAI 协议代理豆包大模型服务

**配置信息**

```yaml
provider:
type: doubao
apiTokens:
- YOUR_DOUBAO_API_KEY
modelMapping:
'*': YOUR_DOUBAO_ENDPOINT
timeout: 1200000
```
### 使用月之暗面配合其原生的文件上下文
Expand Down
102 changes: 102 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package provider

import (
"errors"
"fmt"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

const (
doubaoDomain = "ark.cn-beijing.volces.com"
doubaoChatCompletionPath = "/api/v3/chat/completions"
)

type doubaoProviderInitializer struct{}

func (m *doubaoProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

func (m *doubaoProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &doubaoProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}

type doubaoProvider struct {
config ProviderConfig
contextCache *contextCache
}

func (m *doubaoProvider) GetProviderType() string {
return providerTypeDoubao
}

func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(doubaoDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
}
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(doubaoChatCompletionPath)
return types.ActionContinue, nil
}

func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
} else {
return types.ActionContinue, err
}
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
return types.ActionContinue, err
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}
}
2 changes: 2 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
providerTypeDeepl = "deepl"
providerTypeMistral = "mistral"
providerTypeCohere = "cohere"
providerTypeDoubao = "doubao"

protocolOpenAI = "openai"
protocolOriginal = "original"
Expand Down Expand Up @@ -96,6 +97,7 @@ var (
providerTypeDeepl: &deeplProviderInitializer{},
providerTypeMistral: &mistralProviderInitializer{},
providerTypeCohere: &cohereProviderInitializer{},
providerTypeDoubao: &doubaoProviderInitializer{},
}
)

Expand Down

0 comments on commit bef9139

Please # to comment.