Skip to content

feat: Implement get_tool_call_output method in RunResultBase and update doc #637

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/results.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ The [`new_items`][agents.result.RunResultBase.new_items] property contains the n
- [`ToolCallOutputItem`][agents.items.ToolCallOutputItem] indicates that a tool was called. The raw item is the tool response. You can also access the tool output from the item.
- [`ReasoningItem`][agents.items.ReasoningItem] indicates a reasoning item from the LLM. The raw item is the reasoning generated.

### Getting Tool Call Outputs

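You can use the [`get_tool_call_output(tool_name)`][agents.result.RunResultBase.get_tool_call_output] method to look up the [`ToolCallOutputItem`][agents.items.ToolCallOutputItem] produced by a tool, identified by the tool's name. The tool call and its output are paired through their shared `call_id`, and a single output item is returned. This is useful when you need the result of one particular tool execution: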

```python
result = await Runner.run(agent, "input")
tool_output = result.get_tool_call_output("my_tool_name")
if tool_output:
    # Process tool output
    pass
```

The method returns `None` if the tool wasn't called, if no output item matches the tool call, or if an error occurs while retrieving the output.

## Other information

### Guardrail results
Expand Down
45 changes: 42 additions & 3 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from typing_extensions import TypeVar

Expand All @@ -13,12 +13,14 @@
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .items import (ItemHelpers, ModelResponse, RunItem, ToolCallOutputItem,
TResponseInputItem)
from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
from .util._pretty_print import (pretty_print_result,
pretty_print_run_result_streaming)

if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
Expand Down Expand Up @@ -83,6 +85,43 @@ def to_input_list(self) -> list[TResponseInputItem]:
new_items = [item.to_input_item() for item in self.new_items]

return original_items + new_items

def get_tool_call_output(self, tool_name: str) -> Optional[ToolCallOutputItem]:
    """Return the output item produced by a call to the named tool.

    ``self.new_items`` is scanned for tool call items whose raw item is
    named ``tool_name``. Calls are considered in the order they appear, and
    the first one that has an output item with the same ``call_id`` decides
    the result. Tool call items whose raw item exposes no ``name`` or
    ``call_id`` are skipped instead of aborting the whole lookup.

    Args:
        tool_name: The name of the tool to find.

    Returns:
        The matching tool call output item, or ``None`` when the tool was
        not called or none of its calls has a recorded output.

    Examples:
        >>> result = RunResult(...)
        >>> output = result.get_tool_call_output("calculator")
        >>> if output:
        ...     print(output.raw_item)
    """

    def _read(raw: Any, key: str) -> Any:
        # Raw items show up both as plain mappings and as attribute-style
        # objects, so support both and treat a missing field as None.
        if isinstance(raw, dict):
            return raw.get(key)
        return getattr(raw, key, None)

    # Earliest output item recorded for each call_id.
    first_output_by_call_id: dict[Any, Any] = {}
    for item in self.new_items:
        if getattr(item, "type", None) != "tool_call_output_item":
            continue
        call_id = _read(item.raw_item, "call_id")
        if call_id is not None:
            first_output_by_call_id.setdefault(call_id, item)

    # Walk the calls to this tool in order, so an earlier call still takes
    # precedence, but a call without any output no longer hides later ones.
    for item in self.new_items:
        if getattr(item, "type", None) != "tool_call_item":
            continue
        if _read(item.raw_item, "name") != tool_name:
            continue
        call_id = _read(item.raw_item, "call_id")
        if call_id is not None and call_id in first_output_by_call_id:
            return first_output_by_call_id[call_id]

    return None

@property
def last_response_id(self) -> str | None:
Expand Down
111 changes: 102 additions & 9 deletions tests/test_output_tool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import json
from typing import Any
from dataclasses import dataclass
from typing import Any, List

import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict

from agents import (
Agent,
AgentOutputSchema,
AgentOutputSchemaBase,
ModelBehaviorError,
Runner,
UserError,
)
from agents import (Agent, AgentOutputSchema, AgentOutputSchemaBase,
ModelBehaviorError, Runner, UserError)
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.items import RunItem, ToolCallItem, ToolCallOutputItem
from agents.result import RunResultBase
from agents.util import _json


Expand Down Expand Up @@ -166,3 +163,99 @@ def test_custom_output_schema():
json_str = json.dumps({"foo": "bar"})
validated = output_schema.validate_json(json_str)
assert validated == ["some", "output"]


@dataclass
class MockRunResult(RunResultBase):
    """Concrete ``RunResultBase`` stand-in used by the tests in this module.

    The fields are re-declared without defaults so a test can build a result
    from plain values; ``last_agent`` is stubbed to return ``None``.
    """

    # Values supplied directly by each test.
    input: str
    new_items: List[RunItem]
    raw_responses: List[Any]
    final_output: Any
    input_guardrail_results: List[Any]
    output_guardrail_results: List[Any]
    context_wrapper: Any

    @property
    def last_agent(self) -> None:
        # The tests here never read the last agent, so a placeholder suffices.
        return None

def test_get_tool_call_output():
    """Exercise ``get_tool_call_output`` for a hit and for three kinds of miss."""
    test_agent = Agent(name="test")

    def build_result(items: List[RunItem]) -> MockRunResult:
        # Only ``new_items`` differs between scenarios; the rest is inert filler.
        return MockRunResult(
            input="test input",
            new_items=items,
            raw_responses=[],
            final_output=None,
            input_guardrail_results=[],
            output_guardrail_results=[],
            context_wrapper=None,
        )

    def build_output(call_id: str) -> ToolCallOutputItem:
        # Output item whose raw payload is tagged with the given call_id.
        return ToolCallOutputItem(
            type="tool_call_output_item",
            raw_item={"call_id": call_id, "content": "tool output"},
            agent=test_agent,
            output="tool output",
        )

    # The raw item is a throwaway class exposing ``name`` and ``call_id``.
    call_item = ToolCallItem(
        type="tool_call_item",
        raw_item=type("ToolCall", (), {"name": "test_tool", "call_id": "123"}),
        agent=test_agent,
    )

    # Hit: tool name and call_id both line up.
    populated = build_result([call_item, build_output("123")])
    found = populated.get_tool_call_output("test_tool")
    assert found is not None
    assert found.type == "tool_call_output_item"
    assert found.raw_item["content"] == "tool output"

    # Miss: no tool call with that name.
    assert populated.get_tool_call_output("non_existent_tool") is None

    # Miss: nothing recorded at all.
    assert build_result([]).get_tool_call_output("test_tool") is None

    # Miss: the only output belongs to a different call_id.
    mismatched = build_result([call_item, build_output("456")])
    assert mismatched.get_tool_call_output("test_tool") is None