Skip to content

Commit

Permalink
Merge pull request #559 from Rt39/feat-claude-api
Browse files Browse the repository at this point in the history
添加对Anthropic Claude API的支持
  • Loading branch information
Fridemn authored Feb 21, 2025
2 parents 9cbb6aa + 838cd2d commit 04d204e
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ package.json
venv/*
packages/python_interpreter/workplace
.venv/*

.conda/
12 changes: 12 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,18 @@
"model": "grok-2-latest",
},
},
"anthropic(claude)": {
"id": "claude",
"type": "anthropic_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
},
},
"ollama": {
"id": "ollama_default",
"type": "openai_chat_completion",
Expand Down
23 changes: 23 additions & 0 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,29 @@ def get_func_desc_openai_style(self) -> list:
)
return _l

def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue

# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", [])
}
}
tools.append(tool)
return tools

def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
Expand Down
2 changes: 2 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "anthropic_chat_completion":
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
Expand Down
184 changes: 184 additions & 0 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import List
from mimetypes import guess_type

from anthropic import AsyncAnthropic
from anthropic.types import Message

from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial

@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器")
class ProviderAnthropic(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True,
default_persona: Personality = None
) -> None:
# Skip OpenAI's __init__ and call Provider's __init__ directly
Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona)

self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)

self.client = AsyncAnthropic(
api_key=self.chosen_api_key,
timeout=self.timeout,
base_url=self.base_url
)

self.set_model(provider_config['model_config']['model'])

async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_anthropic_style()
if tool_list:
payloads['tools'] = tool_list

completion = await self.client.messages.create(
**payloads,
stream=False
)

assert isinstance(completion, Message)
logger.debug(f"completion: {completion}")

if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。")
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
content = completion.content[-1]

llm_response = LLMResponse("assistant")

if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text

# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
# tools call (function calling)
args_ls = []
func_name_ls = []
func_name_ls.append(content.name)
args_ls.append(content.input)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls

if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}。")
raise Exception(f"API 返回的 completion 无法解析:{completion}。")

llm_response.raw_completion = completion

return llm_response

async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]

for part in context_query:
if '_no_save' in part:
del part['_no_save']

model_config = self.provider_config.get("model_config", {})

payloads = {
"messages": context_query,
**model_config
}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads['system'] = system_prompt
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)

except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
response = await self.client.messages.create(
messages=context_query,
**model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.raw_completion = response
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
else:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e

return llm_response

async def assemble_context(self, text: str, image_urls: List[str] = None):
'''组装上下文,支持文本和图片'''
if not image_urls:
return {"role": "user", "content": text}

content = []
content.append({"type": "text", "text": text})

for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)

if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue

# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine

content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1] if "base64," in image_data else image_data
}
})

return {"role": "user", "content": content}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pydantic~=2.10.3
aiohttp
openai
anthropic
qq-botpy
chardet~=5.1.0
Pillow
Expand Down

0 comments on commit 04d204e

Please # to comment.