Skip to content

Commit

Permalink
[computer-use-demo] Add prompt caching
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinji committed Oct 23, 2024
1 parent 4acbe40 commit d93012e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
51 changes: 47 additions & 4 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@
from enum import StrEnum
from typing import Any, cast

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIResponse,
BaseModel,
)
from anthropic.types import (
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaCacheControlEphemeralParam,
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
Expand All @@ -24,15 +31,15 @@

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult

BETA_FLAG = "computer-use-2024-10-22"


class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"


MAX_PROMPT_CACHING_BREAKPOINTS = 4

PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
Expand Down Expand Up @@ -74,6 +81,7 @@ async def sampling_loop(
api_key: str,
only_n_most_recent_images: int | None = None,
max_tokens: int = 4096,
prompt_caching: bool = True,
):
"""
Agentic sampling loop for the assistant/tool interaction of computer use.
Expand All @@ -98,6 +106,11 @@ async def sampling_loop(
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

betas = ["computer-use-2024-10-22"]
if prompt_caching:
betas.append("prompt-caching-2024-07-31")
_add_prompt_caching_headers(messages)

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
Expand All @@ -108,7 +121,7 @@ async def sampling_loop(
model=model,
system=system,
tools=tool_collection.to_params(),
betas=["computer-use-2024-10-22"],
betas=betas,
)

api_response_callback(cast(APIResponse[BetaMessage], raw_response))
Expand Down Expand Up @@ -230,3 +243,33 @@ def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
if result.system:
result_text = f"<system>{result.system}</system>\n{result_text}"
return result_text


def _add_prompt_caching_headers(
messages: list[BetaMessageParam],
):
prompt_caching_breakpoints = 0
for message in messages:
if isinstance(message["content"], str):
continue

params: list[BetaContentBlockParam] = []
for content_block in message["content"]:
if isinstance(content_block, BaseModel):
content_block_param = cast(
BetaContentBlockParam, content_block.to_dict()
)
else:
content_block_param = content_block
params.append(content_block_param)

if (
isinstance(content_block_param, dict)
and content_block_param.get("type") == "image"
and prompt_caching_breakpoints < MAX_PROMPT_CACHING_BREAKPOINTS
):
content_block_param["cache_control"] = BetaCacheControlEphemeralParam(
type="ephemeral"
)
prompt_caching_breakpoints += 1
message["content"] = params
2 changes: 1 addition & 1 deletion computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _reset_api_provider():
st.session_state.messages.append(
{
"role": Sender.USER,
"content": [TextBlock(type="text", text=new_message)],
"content": [BetaTextBlock(type="text", text=new_message)],
}
)
_render_message(Sender.USER, new_message)
Expand Down
19 changes: 12 additions & 7 deletions computer-use-demo/tests/loop_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest import mock

from anthropic.types import TextBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage, BetaMessageParam
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaToolUseBlock,
)

from computer_use_demo.loop import APIProvider, sampling_loop

Expand All @@ -13,13 +17,13 @@ async def test_loop():
mock.Mock(
spec=BetaMessage,
content=[
TextBlock(type="text", text="Hello"),
ToolUseBlock(
BetaTextBlock(type="text", text="Hello"),
BetaToolUseBlock(
type="tool_use", id="1", name="computer", input={"action": "test"}
),
],
),
mock.Mock(spec=BetaMessage, content=[TextBlock(type="text", text="Done!")]),
mock.Mock(spec=BetaMessage, content=[BetaTextBlock(type="text", text="Done!")]),
]

tool_collection = mock.AsyncMock()
Expand Down Expand Up @@ -49,7 +53,8 @@ async def test_loop():
)

assert len(result) == 4
assert result[0] == {"role": "user", "content": "Test message"}
assert result[0]["role"] == "user"
assert result[0]["content"] == "Test message"
assert result[1]["role"] == "assistant"
assert result[2]["role"] == "user"
assert result[3]["role"] == "assistant"
Expand All @@ -58,7 +63,7 @@ async def test_loop():
tool_collection.run.assert_called_once_with(
name="computer", tool_input={"action": "test"}
)
output_callback.assert_called_with(TextBlock(text="Done!", type="text"))
output_callback.assert_called_with(BetaTextBlock(text="Done!", type="text"))
assert output_callback.call_count == 3
assert tool_output_callback.call_count == 1
assert api_response_callback.call_count == 2
5 changes: 3 additions & 2 deletions computer-use-demo/tests/streamlit_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from unittest import mock

import pytest
from anthropic.types.beta import BetaTextBlock
from streamlit.testing.v1 import AppTest

from computer_use_demo.streamlit import Sender, TextBlock
from computer_use_demo.streamlit import Sender


@pytest.fixture
Expand All @@ -18,6 +19,6 @@ def test_streamlit(streamlit_app: AppTest):
streamlit_app.chat_input[0].set_value("Hello").run()
assert patch.called
assert patch.call_args.kwargs["messages"] == [
{"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]}
{"role": Sender.USER, "content": [BetaTextBlock(text="Hello", type="text")]}
]
assert not streamlit_app.exception

0 comments on commit d93012e

Please # to comment.