diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index bb959e4b..a4284ddb 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -24,8 +24,6 @@ from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult -BETA_FLAG = "computer-use-2024-10-22" - class APIProvider(StrEnum): ANTHROPIC = "anthropic" @@ -74,6 +72,7 @@ async def sampling_loop( api_key: str, only_n_most_recent_images: int | None = None, max_tokens: int = 4096, + prompt_caching: bool = True, ): """ Agentic sampling loop for the assistant/tool interaction of computer use. @@ -98,6 +97,15 @@ async def sampling_loop( elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() + betas = ["computer-use-2024-10-22"] + if prompt_caching: + betas.append("prompt-caching-2024-07-31") + for message in messages: + if isinstance(message.content, str): + continue + for content_block_param in message.content: + content_block_param["cache_control"] = {"type": "ephemeral"} + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -108,7 +116,7 @@ async def sampling_loop( model=model, system=system, tools=tool_collection.to_params(), - betas=["computer-use-2024-10-22"], + betas=betas, ) api_response_callback(cast(APIResponse[BetaMessage], raw_response)) diff --git a/computer-use-demo/tests/loop_test.py b/computer-use-demo/tests/loop_test.py index 4985dbee..751cbab9 100644 --- a/computer-use-demo/tests/loop_test.py +++ b/computer-use-demo/tests/loop_test.py @@ -49,7 +49,8 @@ async def test_loop(): ) assert len(result) == 4 - assert result[0] == {"role": "user", "content": "Test message"} + assert result[0]["role"] == "user" + assert result[0]["content"] == "Test message" assert result[1]["role"] == "assistant" assert result[2]["role"] == "user" assert result[3]["role"] == "assistant"