From ac70f6004bb0f706e3592370d49ef0d70f91ecd5 Mon Sep 17 00:00:00 2001 From: volkanaydingul_macbook_pro_local Date: Tue, 6 May 2025 18:29:55 +0200 Subject: [PATCH] Add ImageFunctionTool and related functionality This commit introduces the ImageFunctionTool and ImageFunctionToolResult classes, enabling the creation and execution of image-generating tools. The necessary modifications include updates to the tool execution logic, new data classes for handling image function calls, and adjustments to the response processing to accommodate image outputs. Additionally, the input handling in the Runner class has been refined to support the new image function items. Changes include: - New classes: ImageFunctionTool, ImageFunctionToolResult, ToolRunImageFunction - Updated tool execution methods to handle image functions - Modifications to the ProcessedResponse class to include image function results - Enhancements to ItemHelpers for image function output formatting - Adjustments in the Runner class for input item processing These changes enhance the SDK's capabilities for handling image generation tasks alongside existing function tools. --- examples/tools/image_function_tool.py | 33 ++++ examples/tools/media/small.webp | Bin 0 -> 3186 bytes src/agents/__init__.py | 4 + src/agents/_run_impl.py | 113 ++++++++++- src/agents/items.py | 24 ++- src/agents/models/openai_responses.py | 18 +- src/agents/run.py | 9 +- src/agents/tool.py | 182 +++++++++++++++++ src/agents/voice/model.py | 2 + tests/test_extra_headers.py | 9 +- tests/tools/test_image_function_tool.py | 248 ++++++++++++++++++++++++ 11 files changed, 632 insertions(+), 10 deletions(-) create mode 100644 examples/tools/image_function_tool.py create mode 100644 examples/tools/media/small.webp create mode 100644 tests/tools/test_image_function_tool.py diff --git a/examples/tools/image_function_tool.py b/examples/tools/image_function_tool.py new file mode 100644 index 000000000..e80949e85 --- /dev/null +++ b/examples/tools/image_function_tool.py @@ -0,0 +1,33 @@ +import asyncio +import base64 +import os + +from agents import Agent, Runner, image_function_tool + +FILEPATH = os.path.join(os.path.dirname(__file__), "media/small.webp") + + +@image_function_tool +def image_to_base64(path: str) -> str: + """ + This function takes a path to an image and returns a base64 encoded string of the image. + It is used to convert the image to a base64 encoded string so that it can be sent to the LLM. + """ + with open(path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return f"data:image/jpeg;base64,{encoded_string}" + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + tools=[image_to_base64], + ) + + result = await Runner.run(agent, f"Read the image in {FILEPATH} and tell me what you see.") + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/media/small.webp b/examples/tools/media/small.webp new file mode 100644 index 0000000000000000000000000000000000000000..49633857efab07a9ef58eeaecf7ee4473efc9874 GIT binary patch literal 3186 zcmV-&42|d8+<%fJ{)xXF8%g06k@cgK?j7gDbmNXnas%LPJCY*cTPXEp! zYFjG%&uP6LQePXgQXzFbRl8&epH#DU0;A}AP^qHjH4bFC;J<_&Vkup$k|wzUiabdR zYPBsBho*4@$c2Drqz)@78%^;I-k}YlV7dy7j%;y~c^n2@LBzCMzdxI=_X7vCrRJrA z6hGIJaqWd|3UjTLwT3J0RWlTZT+5>UGW>=>EqgQG3qWV4)MyPCqsT*{a(0E|E3Q_e z_UZ0Jh_w^>84D)n_zevv5rvEFSm%at|3ljoe(!){za7h19yV~pBi2fYbV4&;8wr`7 z1sSNXIKMsTry#Lx4z0e{U8?!j9iWTveEKyE5O8<(<#x|lHkB~`x@-B38XM-6dtA({ z07bin_yLe_e1XWLni2oRhcm`SB$M~uG~Wv|X0Qg88I2~t-|j=qv{6q^JYa1{W@d7@ z+NzujtNaMrmdjRLbaull)xGpRn>U(KCDJ`&{|FdxcD3e_gX|SO)zG=0L=!J~t}lkO zn@;S}ZkRMODCByozHfeuXFV)*##)##Kd>XmWRkI_Gu1SRI1{3{8MDcu-*{{S%dvuL z;c4x-Gl zuG&N30RHesQ}p+qd_{^ z$yY(fDuHv5a`tS9wQmTF91c9P@VbUjqVF+lMN$HEeWo33HM90%NG7Mol-6{V^yip3 z_7d4MVIwf4e3>dGB7`+(wG&>baurF0M)EV^-y7P{2z7K#*`O$L>4hdG6cF!USOaf| z#`(|4lAUCFu%XgZo>Wsvi_02(jo##cc@X?5qsWejD@}MBbVBgW@;;SFsT!PEDUDDp zu5qbqzHoto^On90Ml2Q?o^WktDl>&BR#^Wzwj~2b_UQPwN%QJR|Hg6MS$^FQu>-7* z-9qH?2dhO3Sj%5&FV(-;5gt0#kQ*_3bDG^u@v53DMYrCVl((x{L-lv~ND6zh79EA2 zNF6mLtxLagZ9v}#6QB=S4|mg6plT?p8aNaeS^0snb+#rcX{*I)_W&}_ zD6R)c>?Ga&rxC1e1bFvo69=`VDT(2UtXY*?3ibLMVx!U41chR3taHL&9cMUE1Im=X+SS5ql5&VR1cRcKI&k=8I$WYF&$$#KP zmb8{qV3b9v+AJ`Z->2}NZ4{&Vjmh$Jg8rL*g$PfHdJ7t^2O(?{Bo0e|y?eqQEi%~V z@BoVIl^*EsZmE1%zT=PFr+5>bI+dG_VH5quklLsFs^SO81O9*Y-3X7rBT29Bp*wdO z#BG&?(FKN&C4)Ws?V(qpWG1-Y4LNH}EXK_mUK&OOy^ZFuyaCX;Zz%hy!0(1{%#Pja z70phmi)}z3HG`vpFDU^orwi45#T$v_L5wo$>b`jD$f$9bXzmgUjdhAwNe0^cp+Rw_ z9Z~2hq_IV8VpiXucLDjrmRBjnJ4^F-|J$OPWHDa5VG?g)Y}CJ7G<^E#^wn(qptHfr zuSptDgn=MJ?R6APTX`N_g;>~1;plIdrMNS1KdmsAHM1cR9jV%CSIBk88bWBw3AV6c zQFlkU-1+z?R^2jym_NRkL$R7?KMl*oLrMtOfU|KeuPfg+yjMsqJ3K74;36yzaE zJ$#8@C-$T>aa+qNyq4OEtNjycVZ_ND_z}p11fV7*2McM(l2WR{HL8?&niwk(k~j!h zZo~hJLR%I9DVPX6=oBI{w-qYw#wfnso;OLLKmD}+?2(mr>}Fjil+x>Y(itfUWu^{o zs#>s06a^k?vRO*fMotq;E+sPwEi4LSNIiJ5_C5uyqj$H1mXa5TV*}2{5LUX~j*aQt zW6WJ*>Bv^hB+y72J*8Re;bAF*q201e-cU8eC zUH(%>OEgUTz;J$H`wWxFp4Lu=(1xaR1f9CAGw59vI9SAM`kx4EU-AH+CVSJw08LI5 znmNUzaMaYQ7Bb_LhIlGu6N>|4Z5CdG<=GFju;te$Gm^%^U2!_7KG1H>IU|_q5oq)? zvil)+`GvKz_1NgsPwCOoZ1hP$O<=*U;XP|>t2nT4EQ(5t78#EDIcnOX`qL35qv^~a zy42QeDN|0yPDZI)Y5X3ig~*vR@3DOtBvAibW1a+arU0V@UgjfFl4psqou9N@N(EKh znrA{^DFZFGz+KxThlB_HteZnBbKA{Y*2oYBL<0d4VrKs*%oDpL9>3sf%bx=7!>hzt zxk%m9e8Fb$(AC!-fSPG}r_~0CeC}6l#p!Q*;ll!KOo6HY!n$Qav2QjX3~mxkaT1bk z?7HWbqjuaB7;;UTm}ng}<-7u(w~X77NMT{-8gDrbS7Fi?!X@3i2-Lvg4)7!08I_f3 zLs(Fdn*vC_qBLg){|-X^NEZAI@+&w8lLk%<~7nkw)hGc#p`mfF<(blDup@pxH%G+)nf*aL5TAaLz5Tr zFy*jgXti`8vZvr~CfM;ufOD99QTsB%m!f5OYB91TLEMqvPAg^fi|r29-FRb~`VMR7 zG!ZK8XMc6j?^o-@=FnbQ`>>7mo+XCMT;?KADBnBGJrIRjuEDG-w-MiPAT3wHk>t^Z zghp%+FQOH;*7_^3$CKd^cj9Kh&zXzH!+3#+4OoKm%4=# bool: [ self.handoffs, self.functions, + self.image_functions, self.computer_actions, ] ) @@ -207,7 +222,7 @@ async def execute_tools_and_side_effects( new_step_items.extend(processed_response.new_items) # First, lets run the tool calls - function tools and computer actions - function_results, computer_results = await asyncio.gather( + function_results, image_function_results, computer_results = await asyncio.gather( cls.execute_function_tool_calls( agent=agent, tool_runs=processed_response.functions, @@ -215,6 +230,13 @@ async def execute_tools_and_side_effects( context_wrapper=context_wrapper, config=run_config, ), + cls.execute_image_function_tool_calls( + agent=agent, + tool_runs=processed_response.image_functions, + hooks=hooks, + context_wrapper=context_wrapper, + config=run_config, + ), cls.execute_computer_actions( agent=agent, actions=processed_response.computer_actions, @@ -224,6 +246,7 @@ async def execute_tools_and_side_effects( ), ) new_step_items.extend([result.run_item for result in function_results]) + new_step_items.extend([result.run_item for result in image_function_results]) new_step_items.extend(computer_results) # Second, check if there are any handoffs @@ -342,10 +365,14 @@ def process_model_response( run_handoffs = [] functions = [] + image_functions = [] computer_actions = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + image_function_map = { + tool.name: tool for tool in all_tools if isinstance(tool, ImageFunctionTool) + } computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: @@ -393,6 +420,15 @@ def process_model_response( handoff=handoff_map[output.name], ) run_handoffs.append(handoff) + + elif output.name in image_function_map: + items.append(ToolCallItem(raw_item=output, agent=agent)) + image_functions.append( + ToolRunImageFunction( + tool_call=output, + image_function_tool=image_function_map[output.name], + ) + ) # Regular function tool call else: if output.name not in function_map: @@ -415,6 +451,7 @@ def process_model_response( new_items=items, handoffs=run_handoffs, functions=functions, + image_functions=image_functions, computer_actions=computer_actions, tools_used=tools_used, ) @@ -489,6 +526,78 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_image_function_tool_calls( + cls, + *, + agent: Agent[TContext], + tool_runs: list[ToolRunImageFunction], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> list[ImageFunctionToolResult]: + async def run_single_tool( + func_tool: ImageFunctionTool, tool_call: ResponseFunctionToolCall + ) -> Any: + with function_span(func_tool.name) as span_fn: + if config.trace_include_sensitive_data: + span_fn.span_data.input = tool_call.arguments + try: + _, _, result = await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, func_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + func_tool.on_invoke_tool(context_wrapper, tool_call.arguments), + ) + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, func_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + except Exception as e: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Error running tool", + data={"tool_name": func_tool.name, "error": str(e)}, + ) + ) + if isinstance(e, AgentsException): + raise e + raise UserError(f"Error running tool {func_tool.name}: {e}") from e + + if config.trace_include_sensitive_data: + span_fn.span_data.output = result + return result + + tasks = [] + for tool_run in tool_runs: + image_function_tool = tool_run.image_function_tool + tasks.append(run_single_tool(image_function_tool, tool_run.tool_call)) + + results = await asyncio.gather(*tasks) + + return [ + ImageFunctionToolResult( + tool=tool_run.image_function_tool, + output=result, + run_item=ToolCallOutputItem( + output=result, + raw_item=ItemHelpers.image_function_tool_call_output_item( + tool_run.tool_call, result + ), + agent=agent, + ), + ) + for tool_run, result in zip(tool_runs, results) + ] + @classmethod async def execute_computer_actions( cls, diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a3..db7dcd96d 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -57,7 +57,7 @@ class RunItemBase(Generic[T], abc.ABC): def to_input_item(self) -> TResponseInputItem: """Converts this item into an input item suitable for passing to the model.""" - if isinstance(self.raw_item, dict): + if isinstance(self.raw_item, dict) or isinstance(self.raw_item, list): # We know that input items are dicts, so we can ignore the type error return self.raw_item # type: ignore elif isinstance(self.raw_item, BaseModel): @@ -248,3 +248,25 @@ def tool_call_output_item( "output": output, "type": "function_call_output", } + + @classmethod + def image_function_tool_call_output_item( + cls, tool_call: ResponseFunctionToolCall, output: str + ) -> FunctionCallOutput: + """Creates a tool call output item from a tool call and its output.""" + return [ + { + "call_id": tool_call.call_id, + "output": "Image generating tool is called.", + "type": "function_call_output", + }, + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": output, + } + ], + }, + ] diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index c1ff85b98..f6591e388 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -23,7 +23,14 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ( + ComputerTool, + FileSearchTool, + FunctionTool, + ImageFunctionTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -358,6 +365,15 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "description": tool.description, } includes: IncludeLiteral | None = None + elif isinstance(tool, ImageFunctionTool): + converted_tool: ToolParam = { + "name": tool.name, + "parameters": tool.params_json_schema, + "strict": tool.strict_json_schema, + "type": "function", + "description": tool.description, + } + includes: IncludeLiteral | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bfc..43da3dde6 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -755,7 +755,14 @@ async def _run_single_turn( output_schema = cls._get_output_schema(agent) handoffs = cls._get_handoffs(agent) input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) + + # input.extend([generated_item.to_input_item() for generated_item in generated_items]) + for generated_item in generated_items: + input_item_from_generated_item = generated_item.to_input_item() + if isinstance(input_item_from_generated_item, list): + input.extend(input_item_from_generated_item) + else: + input.append(input_item_from_generated_item) new_response = await cls._get_new_response( agent, diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c162423..6f3e11ddc 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -42,6 +42,18 @@ class FunctionToolResult: """The run item that was produced as a result of the tool call.""" +@dataclass +class ImageFunctionToolResult: + tool: ImageFunctionTool + """The tool that was run.""" + + output: Any + """The output of the tool.""" + + run_item: RunItem + """The run item that was produced as a result of the tool call.""" + + @dataclass class FunctionTool: """A tool that wraps a function. In most cases, you should use the `function_tool` helpers to @@ -73,6 +85,37 @@ class FunctionTool: as it increases the likelihood of correct JSON input.""" +@dataclass +class ImageFunctionTool: + """A tool that wraps a function that generates an image. In most cases, you should use the `image_function_tool` helpers to + create a ImageFunctionTool, as they let you easily wrap a Python function. + """ + + name: str + """The name of the tool, as shown to the LLM. Generally the name of the function.""" + + description: str + """A description of the tool, as shown to the LLM.""" + + params_json_schema: dict[str, Any] + """The JSON schema for the tool's parameters.""" + + on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]] + """A function that invokes the tool with the given context and parameters. The params passed + are: + 1. The tool run context. + 2. The arguments from the LLM, as a JSON string. + + You must return a string representation of the tool output, or something we can call `str()` on. + In case of errors, you can either raise an Exception (which will cause the run to fail) or + return a string error message (which will be sent back to the LLM). + """ + + strict_json_schema: bool = True + """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, + as it increases the likelihood of correct JSON input.""" + + @dataclass class FileSearchTool: """A hosted tool that lets the LLM search through a vector store. Currently only supported with @@ -308,3 +351,142 @@ def decorator(real_func: ToolFunction[...]) -> FunctionTool: return _create_function_tool(real_func) return decorator + + +def image_function_tool( + func: ToolFunction[...] | None = None, + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = default_tool_error_function, + strict_mode: bool = True, +) -> ImageFunctionTool | Callable[[ToolFunction[...]], ImageFunctionTool]: + """ + Decorator to create a FunctionTool from a function. By default, we will: + 1. Parse the function signature to create a JSON schema for the tool's parameters. + 2. Use the function's docstring to populate the tool's description. + 3. Use the function's docstring to populate argument descriptions. + The docstring style is detected automatically, but you can override it. + + If the function takes a `RunContextWrapper` as the first argument, it *must* match the + context type of the agent that uses the tool. + + Args: + func: The function to wrap. + name_override: If provided, use this name for the tool instead of the function's name. + description_override: If provided, use this description for the tool instead of the + function's docstring. + docstring_style: If provided, use this style for the tool's docstring. If not provided, + we will attempt to auto-detect the style. + use_docstring_info: If True, use the function's docstring to populate the tool's + description and argument descriptions. + failure_error_function: If provided, use this function to generate an error message when + the tool call fails. The error message is sent to the LLM. If you pass None, then no + error message will be sent and instead an Exception will be raised. + strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly* + recommend setting this to True, as it increases the likelihood of correct JSON input. + If False, it allows non-strict JSON schemas. For example, if a parameter has a default + value, it will be optional, additional properties are allowed, etc. See here for more: + https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas + """ + + def _create_image_function_tool(the_func: ToolFunction[...]) -> ImageFunctionTool: + schema = function_schema( + func=the_func, + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + strict_json_schema=strict_mode, + ) + + async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + json_data: dict[str, Any] = json.loads(input) if input else {} + except Exception as e: + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invalid JSON input for tool {schema.name}") + else: + logger.debug(f"Invalid JSON input for tool {schema.name}: {input}") + raise ModelBehaviorError( + f"Invalid JSON input for tool {schema.name}: {input}" + ) from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invoking tool {schema.name}") + else: + logger.debug(f"Invoking tool {schema.name} with input {input}") + + try: + parsed = ( + schema.params_pydantic_model(**json_data) + if json_data + else schema.params_pydantic_model() + ) + except ValidationError as e: + raise ModelBehaviorError(f"Invalid JSON input for tool {schema.name}: {e}") from e + + args, kwargs_dict = schema.to_call_args(parsed) + + if not _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}") + + if inspect.iscoroutinefunction(the_func): + if schema.takes_context: + result = await the_func(ctx, *args, **kwargs_dict) + else: + result = await the_func(*args, **kwargs_dict) + else: + if schema.takes_context: + result = the_func(ctx, *args, **kwargs_dict) + else: + result = the_func(*args, **kwargs_dict) + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool {schema.name} completed.") + else: + logger.debug(f"Tool {schema.name} returned {result}") + + return result + + async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + return await _on_invoke_tool_impl(ctx, input) + except Exception as e: + if failure_error_function is None: + raise + + result = failure_error_function(ctx, e) + if inspect.isawaitable(result): + return await result + + _error_tracing.attach_error_to_current_span( + SpanError( + message="Error running tool (non-fatal)", + data={ + "tool_name": schema.name, + "error": str(e), + }, + ) + ) + return result + + return ImageFunctionTool( + name=schema.name, + description=schema.description or "", + params_json_schema=schema.params_json_schema, + on_invoke_tool=_on_invoke_tool, + strict_json_schema=strict_mode, + ) + + # If func is actually a callable, we were used as @function_tool with no parentheses + if callable(func): + return _create_image_function_tool(func) + + # Otherwise, we were used as @function_tool(...), so return a decorator + def decorator(real_func: ToolFunction[...]) -> ImageFunctionTool: + return _create_image_function_tool(real_func) + + return decorator diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index c36a4de76..b048a452d 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -17,9 +17,11 @@ TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] """Exportable type for the TTSModelSettings voice enum""" + @dataclass class TTSModelSettings: """Settings for a TTS model.""" + voice: TTSVoice | None = None """ The voice to use for the TTS model. If not provided, the default voice for the respective model diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index f29c25408..8efa95a76 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -17,21 +17,21 @@ class DummyResponses: async def create(self, **kwargs): nonlocal called_kwargs called_kwargs = kwargs + class DummyResponse: id = "dummy" output = [] usage = type( "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} )() + return DummyResponse() class DummyClient: def __init__(self): self.responses = DummyResponses() - - - model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, @@ -47,7 +47,6 @@ def __init__(self): assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" - @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_extra_headers_passed_to_openai_client(): @@ -76,7 +75,7 @@ def __init__(self): self.chat = type("_Chat", (), {"completions": DummyCompletions()})() self.base_url = "https://api.openai.com" - model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, diff --git a/tests/tools/test_image_function_tool.py b/tests/tools/test_image_function_tool.py new file mode 100644 index 000000000..7d22a1c2f --- /dev/null +++ b/tests/tools/test_image_function_tool.py @@ -0,0 +1,248 @@ +import asyncio +import json +from typing import Any + +import pytest +from pydantic import BaseModel +from typing_extensions import TypedDict + +from src.agents import ( + ImageFunctionTool, + ModelBehaviorError, + RunContextWrapper, + image_function_tool, +) +from src.agents.tool import default_tool_error_function + +# A dummy base64 encoded image string (e.g., a tiny 1x1 pixel red PNG) +DUMMY_IMAGE_BASE64 = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + +# =================== Basic Tests for image_function_tool =================== + + +def argless_image_function() -> str: + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_argless_image_function(): + tool = image_function_tool(argless_image_function) + assert tool.name == "argless_image_function" + + result = await tool.on_invoke_tool(RunContextWrapper(None), "") + assert result == DUMMY_IMAGE_BASE64 + + +def argless_with_context_image(ctx: RunContextWrapper[str]) -> str: + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_argless_with_context_image(): + tool = image_function_tool(argless_with_context_image) + assert tool.name == "argless_with_context_image" + + result = await tool.on_invoke_tool(RunContextWrapper(None), "") + assert result == DUMMY_IMAGE_BASE64 + + # Extra JSON should not raise an error + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}') + assert result == DUMMY_IMAGE_BASE64 + + +def simple_image_function(prompt: str, style: str = "realistic") -> str: + # In a real scenario, these parameters would affect the generated image + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_simple_image_function(): + tool = image_function_tool(simple_image_function, failure_error_function=None) + assert tool.name == "simple_image_function" + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "cat"}') + assert result == DUMMY_IMAGE_BASE64 + + result = await tool.on_invoke_tool( + RunContextWrapper(None), '{"prompt": "dog", "style": "cartoon"}' + ) + assert result == DUMMY_IMAGE_BASE64 + + # Missing required argument should raise an error + with pytest.raises(ModelBehaviorError): + await tool.on_invoke_tool(RunContextWrapper(None), "") + + +class ImageParams(BaseModel): + prompt: str + width: int = 512 + height: int = 512 + + +class StyleOptions(TypedDict): + style: str + seed: int + + +def complex_args_image_function(params: ImageParams, style_options: StyleOptions) -> str: + # In a real scenario, these parameters would affect the generated image + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_complex_args_image_function(): + tool = image_function_tool(complex_args_image_function, failure_error_function=None) + assert tool.name == "complex_args_image_function" + + valid_json = json.dumps( + { + "params": ImageParams(prompt="sunset").model_dump(), + "style_options": StyleOptions(style="realistic", seed=42), + } + ) + result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json) + assert result == DUMMY_IMAGE_BASE64 + + valid_json = json.dumps( + { + "params": ImageParams(prompt="mountains", width=1024, height=768).model_dump(), + "style_options": StyleOptions(style="abstract", seed=123), + } + ) + result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json) + assert result == DUMMY_IMAGE_BASE64 + + # Missing required argument should raise an error + with pytest.raises(ModelBehaviorError): + await tool.on_invoke_tool(RunContextWrapper(None), '{"params": {"prompt": "forest"}}') + + +def test_image_function_config_overrides(): + tool = image_function_tool(simple_image_function, name_override="custom_image_name") + assert tool.name == "custom_image_name" + + tool = image_function_tool(simple_image_function, description_override="Generate custom images") + assert tool.description == "Generate custom images" + + tool = image_function_tool( + simple_image_function, + name_override="art_generator", + description_override="Creates beautiful art images", + ) + assert tool.name == "art_generator" + assert tool.description == "Creates beautiful art images" + + +def test_image_function_schema_is_strict(): + tool = image_function_tool(simple_image_function) + assert tool.strict_json_schema, "Should be strict by default" + assert ( + "additionalProperties" in tool.params_json_schema + and not tool.params_json_schema["additionalProperties"] + ) + + tool = image_function_tool(complex_args_image_function) + assert tool.strict_json_schema, "Should be strict by default" + assert ( + "additionalProperties" in tool.params_json_schema + and not tool.params_json_schema["additionalProperties"] + ) + + +@pytest.mark.asyncio +async def test_manual_image_function_tool_creation_works(): + def generate_image(prompt: str) -> str: + return DUMMY_IMAGE_BASE64 + + class ImageArgs(BaseModel): + prompt: str + + async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: + parsed = ImageArgs.model_validate_json(args) + return generate_image(prompt=parsed.prompt) + + tool = ImageFunctionTool( + name="image_creator", + description="Creates images from text prompts", + params_json_schema=ImageArgs.model_json_schema(), + on_invoke_tool=run_function, + ) + + assert tool.name == "image_creator" + assert tool.description == "Creates images from text prompts" + for key, value in ImageArgs.model_json_schema().items(): + assert tool.params_json_schema[key] == value + assert tool.strict_json_schema + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "sunset"}') + assert result == DUMMY_IMAGE_BASE64 + + tool_not_strict = ImageFunctionTool( + name="image_creator", + description="Creates images from text prompts", + params_json_schema=ImageArgs.model_json_schema(), + on_invoke_tool=run_function, + strict_json_schema=False, + ) + + assert not tool_not_strict.strict_json_schema + assert "additionalProperties" not in tool_not_strict.params_json_schema + + result = await tool_not_strict.on_invoke_tool( + RunContextWrapper(None), '{"prompt": "sunset", "style": "realistic"}' + ) + assert result == DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_image_function_tool_default_error_works(): + def failing_image_generator(prompt: str) -> str: + raise ValueError("Image generation failed") + + tool = image_function_tool(failing_image_generator) + ctx = RunContextWrapper(None) + + result = await tool.on_invoke_tool(ctx, "") + assert "Invalid JSON" in str(result) + + result = await tool.on_invoke_tool(ctx, "{}") + assert "Invalid JSON" in str(result) + + result = await tool.on_invoke_tool(ctx, '{"prompt": "sunset"}') + assert result == default_tool_error_function(ctx, ValueError("Image generation failed")) + + +@pytest.mark.asyncio +async def test_sync_custom_error_function_works_for_image_tool(): + def failing_image_generator(prompt: str) -> str: + raise ValueError("Image generation failed") + + def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str: + return f"error_{error.__class__.__name__}_image" + + tool = image_function_tool( + failing_image_generator, failure_error_function=custom_sync_error_function + ) + ctx = RunContextWrapper(None) + + result = await tool.on_invoke_tool(ctx, "") + assert result == "error_ModelBehaviorError_image" + + result = await tool.on_invoke_tool(ctx, "{}") + assert result == "error_ModelBehaviorError_image" + + result = await tool.on_invoke_tool(ctx, '{"prompt": "sunset"}') + assert result == "error_ValueError_image" + + +@pytest.mark.asyncio +async def test_async_image_generator(): + async def async_image_generator(prompt: str) -> str: + # Simulate some async operation + await asyncio.sleep(0.01) + return DUMMY_IMAGE_BASE64 + + tool = image_function_tool(async_image_generator) + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "sunset"}') + assert result == DUMMY_IMAGE_BASE64