Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

ENG-1392: SQL tool in Agents #400

Merged
merged 8 commits into from
Feb 17, 2025
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.modules.agent.tool.sql_tool import SQLTool
from aixplain.modules.model import Model
from aixplain.modules.pipeline import Pipeline
from aixplain.utils import config
@@ -188,6 +189,23 @@ def create_custom_python_code_tool(cls, code: Union[Text, Callable], description
"""Create a new custom python code tool."""
return CustomPythonCodeTool(description=description, code=code)

@classmethod
def create_sql_tool(
cls, description: Text, database: Text, schema: Optional[Text] = None, tables: Optional[List[Text]] = None
) -> SQLTool:
"""Create a new SQL tool

Args:
description (Text): description of the database tool
database (Text): URL/local path of the SQLite database file
schema (Optional[Text], optional): database schema description (optional)
tables (Optional[List[Text]], optional): table names to work with (optional)

Returns:
SQLTool: created SQLTool
"""
return SQLTool(description=description, database=database, schema=schema, tables=tables)

@classmethod
def list(cls) -> Dict:
"""List all agents available in the platform."""
8 changes: 8 additions & 0 deletions aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.modules.agent.tool.sql_tool import SQLTool
from typing import Dict, Text
from urllib.parse import urljoin

@@ -45,6 +46,13 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent:
tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"])
else:
tool = PythonInterpreterTool()
elif tool["type"] == "sql":
parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])}
database = parameters.get("database")
schema = parameters.get("schema")
tables = parameters.get("tables", None)
tables = tables.split(",") if tables is not None else None
tool = SQLTool(description=tool["description"], database=database, schema=schema, tables=tables)
else:
raise Exception("Agent Creation Error: Tool type not supported.")
tools.append(tool)
14 changes: 12 additions & 2 deletions aixplain/factories/file_factory.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
MB_25 = 26214400
MB_50 = 52428800
MB_300 = 314572800
MB_500 = 524288000


class FileFactory:
@@ -65,8 +66,17 @@ def upload(
else:
content_type = mime_type

type_to_max_size = {"audio": MB_50, "application": MB_25, "video": MB_300, "image": MB_25, "other": MB_50}
if mime_type is None or mime_type.split("/")[0] not in type_to_max_size:
type_to_max_size = {
"audio": MB_50,
"application": MB_25,
"video": MB_300,
"image": MB_25,
"other": MB_50,
"database": MB_500,
}
if local_path.endswith(".db"):
ftype = "database"
elif mime_type is None or mime_type.split("/")[0] not in type_to_max_size:
ftype = "other"
else:
ftype = mime_type.split("/")[0]
87 changes: 87 additions & 0 deletions aixplain/modules/agent/tool/sql_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
__author__ = "aiXplain"

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: Lucas Pavanelli and Thiago Castro Ferreira
Date: May 16th 2024
Description:
Agentification Class
"""
import os
import validators
from typing import Text, Optional, Dict, List, Union

from aixplain.modules.agent.tool import Tool


class SQLTool(Tool):
"""Tool to execute SQL commands in an SQLite database.

Attributes:
description (Text): description of the tool
database (Text): database name
schema (Text): database schema description
tables (Optional[Union[List[Text], Text]]): table names to work with (optional)
"""

def __init__(
self,
description: Text,
database: Text,
schema: Text,
tables: Optional[Union[List[Text], Text]] = None,
**additional_info,
) -> None:
"""Tool to execute SQL query commands in an SQLite database.

Args:
description (Text): description of the tool
database (Text): database name
schema (Text): database schema description
tables (Optional[Union[List[Text], Text]]): table names to work with (optional)
"""
super().__init__("", description, **additional_info)
self.database = database
self.schema = schema
self.tables = tables if isinstance(tables, list) else [tables] if tables else None

def to_dict(self) -> Dict[str, Text]:
return {
"description": self.description,
"parameters": [
{"name": "database", "value": self.database},
{"name": "schema", "value": self.schema},
{"name": "tables", "value": ",".join(self.tables) if self.tables is not None else None},
],
"type": "sql",
}

def validate(self):
from aixplain.factories.file_factory import FileFactory

assert self.description and self.description.strip() != "", "SQL Tool Error: Description is required"
assert self.database and self.database.strip() != "", "SQL Tool Error: Database is required"
if not (
str(self.database).startswith("s3://")
or str(self.database).startswith("http://")
or str(self.database).startswith("https://")
or validators.url(self.database)
):
if not os.path.exists(self.database):
raise Exception(f"SQL Tool Error: Database '{self.database}' does not exist")
if not self.database.endswith(".db"):
raise Exception(f"SQL Tool Error: Database '{self.database}' must have .db extension")
self.database = FileFactory.upload(local_path=self.database, is_temp=True)
30 changes: 14 additions & 16 deletions aixplain/v2/agent.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
TYPE_CHECKING,
Callable,
NotRequired,
Optional,
)

from .resource import (
@@ -27,7 +28,7 @@
PythonInterpreterTool,
)
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool

from aixplain.modules.agent.tool.sql_tool import SQLTool
from .enums import Function


@@ -97,20 +98,14 @@ def create_model_tool(
):
from aixplain.factories import AgentFactory

return AgentFactory.create_model_tool(
model=model, function=function, supplier=supplier, description=description
)
return AgentFactory.create_model_tool(model=model, function=function, supplier=supplier, description=description)

@classmethod
def create_pipeline_tool(
cls, description: str, pipeline: Union["Pipeline", str]
) -> "PipelineTool":
def create_pipeline_tool(cls, description: str, pipeline: Union["Pipeline", str]) -> "PipelineTool":
"""Create a new pipeline tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_pipeline_tool(
description=description, pipeline=pipeline
)
return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline)

@classmethod
def create_python_interpreter_tool(cls) -> "PythonInterpreterTool":
@@ -120,12 +115,15 @@ def create_python_interpreter_tool(cls) -> "PythonInterpreterTool":
return AgentFactory.create_python_interpreter_tool()

@classmethod
def create_custom_python_code_tool(
cls, code: Union[str, Callable], description: str = ""
) -> "CustomPythonCodeTool":
def create_custom_python_code_tool(cls, code: Union[str, Callable], description: str = "") -> "CustomPythonCodeTool":
"""Create a new custom python code tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_custom_python_code_tool(
code=code, description=description
)
return AgentFactory.create_custom_python_code_tool(code=code, description=description)

@classmethod
def create_sql_tool(cls, description: str, database: str, schema: str, table: Optional[str] = None) -> "SQLTool":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have tables as well:
tables: Optional[List[Text]]

"""Create a new SQL tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_sql_tool(description=description, database=database, schema=schema, tables=table)
23 changes: 23 additions & 0 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
@@ -345,3 +345,26 @@ def test_specific_model_parameters_e2e(tool_config):
tool_used = True
break
assert tool_used, "Tool was not used in execution"


@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent])
def test_sql_tool(delete_agents_and_team_agents, AgentFactory):
assert delete_agents_and_team_agents
tool = AgentFactory.create_sql_tool(
description="Execute an SQL query and return the result",
database="https://aixplain-platform-assets.s3.us-east-1.amazonaws.com/samples/tests/test.db",
schema="employees (id INT PRIMARY KEY, name VARCHAR(100), age INT, salary DECIMAL(10, 2), department VARCHAR(100))",
)
assert tool is not None
assert tool.description == "Execute an SQL query and return the result"
agent = AgentFactory.create(
name="Teste",
description="You are a test agent that search for employee information in a database",
tools=[tool],
)
assert agent is not None
response = agent.run("What is the name of the employee with the highest salary?")
assert response is not None
assert response["completed"] is True
assert response["status"].lower() == "success"
assert "eve" in str(response["data"]["output"]).lower()
24 changes: 24 additions & 0 deletions tests/unit/agent/sql_tool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from aixplain.factories import AgentFactory
from aixplain.modules.agent.tool.sql_tool import SQLTool


def test_create_sql_tool(mocker):
tool = AgentFactory.create_sql_tool(description="Test", database="test.db", schema="test", tables=["test", "test2"])
assert isinstance(tool, SQLTool)
assert tool.description == "Test"
assert tool.database == "test.db"
assert tool.schema == "test"
assert tool.tables == ["test", "test2"]

tool_dict = tool.to_dict()
assert tool_dict["description"] == "Test"
assert tool_dict["parameters"] == [
{"name": "database", "value": "test.db"},
{"name": "schema", "value": "test"},
{"name": "tables", "value": "test,test2"},
]

mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://test.db")
mocker.patch("os.path.exists", return_value=True)
tool.validate()
assert tool.database == "s3://test.db"