Skip to content

Commit

Permalink
feat: add Anthropic API support with custom version header (#934)
Browse files Browse the repository at this point in the history
* feat: add Anthropic API support with custom version header

* refactor: use switch statement for API type header handling

* refactor: add OpenAI & AzureAD types to be exhaustive

* Update client.go

need explicit fallthrough in empty case statements

* constant for APIVersion; addtl tests
  • Loading branch information
danackerson authored Feb 25, 2025
1 parent 85f578b commit be2e238
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 6 deletions.
18 changes: 13 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream

func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
switch c.config.APIType {
case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
case APITypeAnthropic:
// https://docs.anthropic.com/en/api/versioning
req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
}

if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
Expand Down
15 changes: 15 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
}
}

func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

client.setCommonHeaders(req)

if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}

func TestDecodeResponse(t *testing.T) {
stringInput := ""

Expand Down
22 changes: 21 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const (

azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments"

AnthropicAPIVersion = "2023-06-01"
)

type APIType string
Expand All @@ -20,6 +22,7 @@ const (
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
)

const AzureAPIKeyHeader = "api-key"
Expand All @@ -37,7 +40,7 @@ type ClientConfig struct {
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer
Expand Down Expand Up @@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
}
}

func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,

HTTPClient: &http.Client{},

EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}

func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>"
}
Expand Down
40 changes: 40 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
})
}
}

func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"

config := openai.DefaultAnthropicConfig(apiKey, baseURL)

if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}

if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}

if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}

if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}

func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")

if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}

if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}

expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}

0 comments on commit be2e238

Please # to comment.