Skip to content

Commit

Permalink
fix: incorporate review feedback
Browse files Browse the repository at this point in the history
Signed-off-by: MICHAEL DESMOND <mdesmond@us.ibm.com>
  • Loading branch information
michael-desmond committed Mar 6, 2025
1 parent 1f989bc commit bed0857
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 50 deletions.
9 changes: 5 additions & 4 deletions python/beeai_framework/adapters/langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from beeai_framework.context import RunContext
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.tools.tool import Tool, ToolRunOptions
from beeai_framework.tools.tool import StringToolOutput, Tool, ToolRunOptions
from beeai_framework.utils.strings import to_safe_word


Expand All @@ -35,7 +35,7 @@ class LangChainToolRunOptions(ToolRunOptions):
T = TypeVar("T", bound=BaseModel)


class LangChainTool(Tool[T, LangChainToolRunOptions]):
class LangChainTool(Tool[T, LangChainToolRunOptions, StringToolOutput]):
@property
def name(self) -> str:
return self._tool.name
Expand All @@ -58,7 +58,7 @@ def __init__(self, tool: StructuredTool | LangChainSimpleTool, options: dict[str
super().__init__(options)
self._tool = tool

async def _run(self, input: T, options: LangChainToolRunOptions | None, context: RunContext) -> Any:
async def _run(self, input: T, options: LangChainToolRunOptions | None, context: RunContext) -> StringToolOutput:
langchain_runnable_config = options.langchain_runnable_config or {} if options else {}
args = (
input if isinstance(input, dict) else input.model_dump(),
Expand All @@ -74,4 +74,5 @@ async def _run(self, input: T, options: LangChainToolRunOptions | None, context:
response = await self._tool.ainvoke(*args)
else:
response = self._tool.invoke(*args)
return response

return StringToolOutput(result=str(response))
2 changes: 1 addition & 1 deletion python/beeai_framework/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from beeai_framework.utils.asynchronous import ensure_async
from beeai_framework.utils.custom_logger import BeeLogger

R = TypeVar("R", bound=BaseModel)
R = TypeVar("R")

logger = BeeLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions python/beeai_framework/tools/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
from beeai_framework.context import RunContext
from beeai_framework.emitter import Emitter
from beeai_framework.tools import Tool
from beeai_framework.tools.tool import JSONToolOutput, ToolRunOptions
from beeai_framework.tools.tool import JSONToolOutput, ToolOutput, ToolRunOptions
from beeai_framework.utils import BeeLogger
from beeai_framework.utils.models import json_to_model
from beeai_framework.utils.strings import to_safe_word

logger = BeeLogger(__name__)


class MCPTool(Tool[BaseModel, ToolRunOptions]):
class MCPTool(Tool[BaseModel, ToolRunOptions, ToolOutput]):
"""Tool implementation for Model Context Protocol."""

def __init__(self, server_params: StdioServerParameters, tool: MCPToolInfo, **options: int) -> None:
Expand Down
2 changes: 1 addition & 1 deletion python/beeai_framework/tools/search/duckduckgo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DuckDuckGoSearchToolOutput(SearchToolOutput):
pass


class DuckDuckGoSearchTool(Tool[DuckDuckGoSearchToolInput, ToolRunOptions]):
class DuckDuckGoSearchTool(Tool[DuckDuckGoSearchToolInput, ToolRunOptions, DuckDuckGoSearchToolOutput]):
name = "DuckDuckGo"
description = "Search for online trends, news, current events, real-time information, or research topics."
input_schema = DuckDuckGoSearchToolInput
Expand Down
2 changes: 1 addition & 1 deletion python/beeai_framework/tools/search/wikipedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class WikipediaToolOutput(SearchToolOutput):
pass


class WikipediaTool(Tool[WikipediaToolInput, ToolRunOptions]):
class WikipediaTool(Tool[WikipediaToolInput, ToolRunOptions, WikipediaToolOutput]):
name = "Wikipedia"
description = "Search factual and historical information, including biography, \
history, politics, geography, society, culture, science, technology, people, \
Expand Down
24 changes: 16 additions & 8 deletions python/beeai_framework/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __str__(self) -> str:
return self.get_text_content()


OUT = TypeVar("OUT", bound=ToolOutput)


class StringToolOutput(ToolOutput):
def __init__(self, result: str = "") -> None:
super().__init__()
Expand All @@ -83,7 +86,7 @@ def is_empty(self) -> bool:
return not self.result


class Tool(Generic[IN, OPT], ABC):
class Tool(Generic[IN, OPT, OUT], ABC):
def __init__(self, options: dict[str, Any] | None = None) -> None:
self.options: dict[str, Any] | None = options or None

Expand Down Expand Up @@ -111,7 +114,7 @@ def _create_emitter(self) -> Emitter:
pass

@abstractmethod
async def _run(self, input: IN, options: OPT | None, context: RunContext) -> Any:
async def _run(self, input: IN, options: OPT | None, context: RunContext) -> OUT:
pass

def validate_input(self, input: IN | dict[str, Any]) -> IN:
Expand All @@ -120,8 +123,8 @@ def validate_input(self, input: IN | dict[str, Any]) -> IN:
except ValidationError as e:
raise ToolInputValidationError("Tool input validation error", cause=e)

def run(self, input: IN | dict[str, Any], options: OPT | None = None) -> Run[IN]:
async def run_tool(context: RunContext) -> IN:
def run(self, input: IN | dict[str, Any], options: OPT | None = None) -> Run[OUT]:
async def run_tool(context: RunContext) -> OUT:
error_propagated = False

try:
Expand Down Expand Up @@ -217,7 +220,7 @@ def tool(tool_function: Callable) -> Tool:
if tool_description is None:
raise ValueError("No tool description provided.")

class FunctionTool(Tool[Any, ToolRunOptions]):
class FunctionTool(Tool[Any, ToolRunOptions, ToolOutput]):
name = tool_name
description = tool_description or ""
input_schema = tool_input
Expand All @@ -231,12 +234,17 @@ def _create_emitter(self) -> Emitter:
creator=self,
)

async def _run(self, input: Any, options: ToolRunOptions | None, context: RunContext) -> StringToolOutput:
async def _run(self, input: Any, options: ToolRunOptions | None, context: RunContext) -> ToolOutput:
tool_input_dict = input.model_dump()
if inspect.iscoroutinefunction(tool_function):
return StringToolOutput(str(await tool_function(**tool_input_dict)))
result = await tool_function(**tool_input_dict)
else:
result = tool_function(**tool_input_dict)

if isinstance(result, ToolOutput):
return result
else:
return StringToolOutput(str(tool_function(**tool_input_dict)))
return StringToolOutput(result=str(result))

f_tool = FunctionTool()
return f_tool
2 changes: 1 addition & 1 deletion python/beeai_framework/tools/weather/openmeteo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class OpenMeteoToolInput(BaseModel):
)


class OpenMeteoTool(Tool[OpenMeteoToolInput, ToolRunOptions]):
class OpenMeteoTool(Tool[OpenMeteoToolInput, ToolRunOptions, StringToolOutput]):
name = "OpenMeteoTool"
description = "Retrieve current, past, or future weather forecasts for a location."
input_schema = OpenMeteoToolInput
Expand Down
18 changes: 10 additions & 8 deletions python/docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ class RiddleToolInput(BaseModel):
riddle_number: int = Field(description="Index of riddle to retrieve.")


class RiddleTool(Tool[RiddleToolInput, ToolRunOptions]):
class RiddleTool(Tool[RiddleToolInput, ToolRunOptions, StringToolOutput]):
name = "Riddle"
description = "It selects a riddle to test your knowledge."
input_schema = RiddleToolInput
Expand Down Expand Up @@ -435,7 +435,7 @@ from beeai_framework.context import RunContext
from beeai_framework.emitter.emitter import Emitter
from beeai_framework.errors import FrameworkError
from beeai_framework.tools import ToolInputValidationError
from beeai_framework.tools.tool import Tool, ToolRunOptions
from beeai_framework.tools.tool import JSONToolOutput, Tool, ToolRunOptions


class OpenLibraryToolInput(BaseModel):
Expand All @@ -450,7 +450,7 @@ class OpenLibraryToolResult(BaseModel):
bib_key: str


class OpenLibraryTool(Tool[OpenLibraryToolInput, ToolRunOptions]):
class OpenLibraryTool(Tool[OpenLibraryToolInput, ToolRunOptions, JSONToolOutput]):
name = "OpenLibrary"
description = """Provides access to a library of books with information about book titles,
authors, contributors, publication dates, publisher and isbn."""
Expand All @@ -467,7 +467,7 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput, ToolRunOptions]):

async def _run(
self, tool_input: OpenLibraryToolInput, options: ToolRunOptions | None, context: RunContext
) -> OpenLibraryToolResult:
) -> JSONToolOutput:
key = ""
value = ""
input_vars = vars(tool_input)
Expand All @@ -489,10 +489,12 @@ class OpenLibraryTool(Tool[OpenLibraryToolInput, ToolRunOptions]):

json_output = response.json()[f"{key}:{value}"]

return OpenLibraryToolResult(
preview_url=json_output.get("preview_url", ""),
info_url=json_output.get("info_url", ""),
bib_key=json_output.get("bib_key", ""),
return JSONToolOutput(
result={
"preview_url": json_output.get("preview_url", ""),
"info_url": json_output.get("info_url", ""),
"bib_key": json_output.get("bib_key", ""),
}
)


Expand Down
Loading

0 comments on commit bed0857

Please # to comment.