Skip to content

feat: Add support for image function tools #654

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/tools/image_function_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
import base64
import os

from agents import Agent, Runner, image_function_tool

FILEPATH = os.path.join(os.path.dirname(__file__), "media/small.webp")


@image_function_tool
def image_to_base64(path: str) -> str:
"""
This function takes a path to an image and returns a base64 encoded string of the image.
It is used to convert the image to a base64 encoded string so that it can be sent to the LLM.
"""
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return f"data:image/jpeg;base64,{encoded_string}"


async def main():
agent = Agent(
name="Assistant",
instructions="You are a helpful assistant.",
tools=[image_to_base64],
)

result = await Runner.run(agent, f"Read the image in {FILEPATH} and tell me what you see.")
print(result.final_output)


if __name__ == "__main__":
asyncio.run(main())
Binary file added examples/tools/media/small.webp
Binary file not shown.
4 changes: 4 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@
FileSearchTool,
FunctionTool,
FunctionToolResult,
ImageFunctionTool,
Tool,
WebSearchTool,
default_tool_error_function,
function_tool,
image_function_tool,
)
from .tracing import (
AgentSpanData,
Expand Down Expand Up @@ -203,12 +205,14 @@ def enable_verbose_stdout_logging():
"AgentUpdatedStreamEvent",
"StreamEvent",
"FunctionTool",
"ImageFunctionTool",
"FunctionToolResult",
"ComputerTool",
"FileSearchTool",
"Tool",
"WebSearchTool",
"function_tool",
"image_function_tool",
"Usage",
"add_trace_processor",
"agent_span",
Expand Down
113 changes: 111 additions & 2 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
from .tool import (
ComputerTool,
FunctionTool,
FunctionToolResult,
ImageFunctionTool,
ImageFunctionToolResult,
Tool,
)
from .tracing import (
SpanError,
Trace,
Expand Down Expand Up @@ -106,6 +113,12 @@ class ToolRunFunction:
function_tool: FunctionTool


@dataclass
class ToolRunImageFunction:
tool_call: ResponseFunctionToolCall
image_function_tool: ImageFunctionTool


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
Expand All @@ -117,6 +130,7 @@ class ProcessedResponse:
new_items: list[RunItem]
handoffs: list[ToolRunHandoff]
functions: list[ToolRunFunction]
image_functions: list[ToolRunImageFunction]
computer_actions: list[ToolRunComputerAction]
tools_used: list[str] # Names of all tools used, including hosted tools

Expand All @@ -127,6 +141,7 @@ def has_tools_to_run(self) -> bool:
[
self.handoffs,
self.functions,
self.image_functions,
self.computer_actions,
]
)
Expand Down Expand Up @@ -207,14 +222,21 @@ async def execute_tools_and_side_effects(
new_step_items.extend(processed_response.new_items)

# First, lets run the tool calls - function tools and computer actions
function_results, computer_results = await asyncio.gather(
function_results, image_function_results, computer_results = await asyncio.gather(
cls.execute_function_tool_calls(
agent=agent,
tool_runs=processed_response.functions,
hooks=hooks,
context_wrapper=context_wrapper,
config=run_config,
),
cls.execute_image_function_tool_calls(
agent=agent,
tool_runs=processed_response.image_functions,
hooks=hooks,
context_wrapper=context_wrapper,
config=run_config,
),
cls.execute_computer_actions(
agent=agent,
actions=processed_response.computer_actions,
Expand All @@ -224,6 +246,7 @@ async def execute_tools_and_side_effects(
),
)
new_step_items.extend([result.run_item for result in function_results])
new_step_items.extend([result.run_item for result in image_function_results])
new_step_items.extend(computer_results)

# Second, check if there are any handoffs
Expand Down Expand Up @@ -342,10 +365,14 @@ def process_model_response(

run_handoffs = []
functions = []
image_functions = []
computer_actions = []
tools_used: list[str] = []
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
image_function_map = {
tool.name: tool for tool in all_tools if isinstance(tool, ImageFunctionTool)
}
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)

for output in response.output:
Expand Down Expand Up @@ -393,6 +420,15 @@ def process_model_response(
handoff=handoff_map[output.name],
)
run_handoffs.append(handoff)

elif output.name in image_function_map:
items.append(ToolCallItem(raw_item=output, agent=agent))
image_functions.append(
ToolRunImageFunction(
tool_call=output,
image_function_tool=image_function_map[output.name],
)
)
# Regular function tool call
else:
if output.name not in function_map:
Expand All @@ -415,6 +451,7 @@ def process_model_response(
new_items=items,
handoffs=run_handoffs,
functions=functions,
image_functions=image_functions,
computer_actions=computer_actions,
tools_used=tools_used,
)
Expand Down Expand Up @@ -489,6 +526,78 @@ async def run_single_tool(
for tool_run, result in zip(tool_runs, results)
]

@classmethod
async def execute_image_function_tool_calls(
cls,
*,
agent: Agent[TContext],
tool_runs: list[ToolRunImageFunction],
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[ImageFunctionToolResult]:
async def run_single_tool(
func_tool: ImageFunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),
)
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": func_tool.name, "error": str(e)},
)
)
if isinstance(e, AgentsException):
raise e
raise UserError(f"Error running tool {func_tool.name}: {e}") from e

if config.trace_include_sensitive_data:
span_fn.span_data.output = result
return result

tasks = []
for tool_run in tool_runs:
image_function_tool = tool_run.image_function_tool
tasks.append(run_single_tool(image_function_tool, tool_run.tool_call))

results = await asyncio.gather(*tasks)

return [
ImageFunctionToolResult(
tool=tool_run.image_function_tool,
output=result,
run_item=ToolCallOutputItem(
output=result,
raw_item=ItemHelpers.image_function_tool_call_output_item(
tool_run.tool_call, result
),
agent=agent,
),
)
for tool_run, result in zip(tool_runs, results)
]

@classmethod
async def execute_computer_actions(
cls,
Expand Down
24 changes: 23 additions & 1 deletion src/agents/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RunItemBase(Generic[T], abc.ABC):

def to_input_item(self) -> TResponseInputItem:
"""Converts this item into an input item suitable for passing to the model."""
if isinstance(self.raw_item, dict):
if isinstance(self.raw_item, dict) or isinstance(self.raw_item, list):
# We know that input items are dicts, so we can ignore the type error
return self.raw_item # type: ignore
elif isinstance(self.raw_item, BaseModel):
Expand Down Expand Up @@ -248,3 +248,25 @@ def tool_call_output_item(
"output": output,
"type": "function_call_output",
}

@classmethod
def image_function_tool_call_output_item(
cls, tool_call: ResponseFunctionToolCall, output: str
) -> FunctionCallOutput:
"""Creates a tool call output item from a tool call and its output."""
return [
{
"call_id": tool_call.call_id,
"output": "Image generating tool is called.",
"type": "function_call_output",
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": output,
}
],
},
]
18 changes: 17 additions & 1 deletion src/agents/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from ..handoffs import Handoff
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
from ..logger import logger
from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool
from ..tool import (
ComputerTool,
FileSearchTool,
FunctionTool,
ImageFunctionTool,
Tool,
WebSearchTool,
)
from ..tracing import SpanError, response_span
from ..usage import Usage
from ..version import __version__
Expand Down Expand Up @@ -358,6 +365,15 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]:
"description": tool.description,
}
includes: IncludeLiteral | None = None
elif isinstance(tool, ImageFunctionTool):
converted_tool: ToolParam = {
"name": tool.name,
"parameters": tool.params_json_schema,
"strict": tool.strict_json_schema,
"type": "function",
"description": tool.description,
}
includes: IncludeLiteral | None = None
elif isinstance(tool, WebSearchTool):
ws: WebSearchToolParam = {
"type": "web_search_preview",
Expand Down
9 changes: 8 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,14 @@ async def _run_single_turn(
output_schema = cls._get_output_schema(agent)
handoffs = cls._get_handoffs(agent)
input = ItemHelpers.input_to_new_input_list(original_input)
input.extend([generated_item.to_input_item() for generated_item in generated_items])

# input.extend([generated_item.to_input_item() for generated_item in generated_items])
for generated_item in generated_items:
input_item_from_generated_item = generated_item.to_input_item()
if isinstance(input_item_from_generated_item, list):
input.extend(input_item_from_generated_item)
else:
input.append(input_item_from_generated_item)

new_response = await cls._get_new_response(
agent,
Expand Down
Loading