Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

ENG-1392: SQL tool in Agents #400

Merged
merged 8 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.modules.agent.tool.sql_tool import SQLTool
from aixplain.modules.model import Model
from aixplain.modules.pipeline import Pipeline
from aixplain.utils import config
Expand Down Expand Up @@ -188,6 +189,11 @@ def create_custom_python_code_tool(cls, code: Union[Text, Callable], description
"""Create a new custom python code tool."""
return CustomPythonCodeTool(description=description, code=code)

@classmethod
def create_sql_tool(cls, description: Text, database: Text, schema: Text, table: Optional[Text] = None) -> SQLTool:
"""Create a new SQL tool."""
return SQLTool(description=description, database=database, schema=schema, table=table)

@classmethod
def list(cls) -> Dict:
"""List all agents available in the platform."""
Expand Down
7 changes: 7 additions & 0 deletions aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.modules.agent.tool.sql_tool import SQLTool
from typing import Dict, Text
from urllib.parse import urljoin

Expand Down Expand Up @@ -45,6 +46,12 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent:
tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"])
else:
tool = PythonInterpreterTool()
elif tool["type"] == "sql":
parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])}
database = parameters.get("database")
schema = parameters.get("schema")
table = parameters.get("table", None)
tool = SQLTool(description=tool["description"], database=database, schema=schema, table=table)
else:
raise Exception("Agent Creation Error: Tool type not supported.")
tools.append(tool)
Expand Down
86 changes: 86 additions & 0 deletions aixplain/modules/agent/tool/sql_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
__author__ = "aiXplain"

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: Lucas Pavanelli and Thiago Castro Ferreira
Date: May 16th 2024
Description:
Agentification Class
"""
import os
import validators
from typing import Text, Optional, Dict

from aixplain.modules.agent.tool import Tool


class SQLTool(Tool):
"""Tool to execute SQL commands in an SQLite database.

Attributes:
description (Text): description of the tool
database (Text): database name
schema (Text): database schema description
table (Optional[Text]): table name (optional)
"""

def __init__(
self,
description: Text,
database: Text,
schema: Text,
table: Optional[Text] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tables?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also can we support csv uploading. Please check my comment in agentification pr.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enable_commit is missing

**additional_info,
) -> None:
"""Tool to execute SQL query commands in an SQLite database.

Args:
description (Text): description of the tool
database (Text): database name
schema (Text): database schema description
table (Optional[Text]): table name (optional)
"""
super().__init__("", description, **additional_info)
self.database = database
self.schema = schema
self.table = table

def to_dict(self) -> Dict[str, Text]:
return {
"description": self.description,
"parameters": [
{"name": "database", "value": self.database},
{"name": "schema", "value": self.schema},
{"name": "table", "value": self.table},
],
"type": "sql",
}

def validate(self):
from aixplain.factories.file_factory import FileFactory

assert self.description and self.description.strip() != "", "SQL Tool Error: Description is required"
assert self.database and self.database.strip() != "", "SQL Tool Error: Database is required"
if not (
str(self.database).startswith("s3://")
or str(self.database).startswith("http://")
or str(self.database).startswith("https://")
or validators.url(self.database)
):
if not os.path.exists(self.database):
raise Exception(f"SQL Tool Error: Database '{self.database}' does not exist")
self.database = FileFactory.upload(local_path=self.database, is_temp=True)
assert self.schema and self.schema.strip() != "", "SQL Tool Error: Schema is required"
30 changes: 14 additions & 16 deletions aixplain/v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Callable,
NotRequired,
Optional,
)

from .resource import (
Expand All @@ -27,7 +28,7 @@
PythonInterpreterTool,
)
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool

from aixplain.modules.agent.tool.sql_tool import SQLTool
from .enums import Function


Expand Down Expand Up @@ -97,20 +98,14 @@ def create_model_tool(
):
from aixplain.factories import AgentFactory

return AgentFactory.create_model_tool(
model=model, function=function, supplier=supplier, description=description
)
return AgentFactory.create_model_tool(model=model, function=function, supplier=supplier, description=description)

@classmethod
def create_pipeline_tool(
cls, description: str, pipeline: Union["Pipeline", str]
) -> "PipelineTool":
def create_pipeline_tool(cls, description: str, pipeline: Union["Pipeline", str]) -> "PipelineTool":
"""Create a new pipeline tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_pipeline_tool(
description=description, pipeline=pipeline
)
return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline)

@classmethod
def create_python_interpreter_tool(cls) -> "PythonInterpreterTool":
Expand All @@ -120,12 +115,15 @@ def create_python_interpreter_tool(cls) -> "PythonInterpreterTool":
return AgentFactory.create_python_interpreter_tool()

@classmethod
def create_custom_python_code_tool(
cls, code: Union[str, Callable], description: str = ""
) -> "CustomPythonCodeTool":
def create_custom_python_code_tool(cls, code: Union[str, Callable], description: str = "") -> "CustomPythonCodeTool":
"""Create a new custom python code tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_custom_python_code_tool(
code=code, description=description
)
return AgentFactory.create_custom_python_code_tool(code=code, description=description)

@classmethod
def create_sql_tool(cls, description: str, database: str, schema: str, table: Optional[str] = None) -> "SQLTool":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have tables as well:
tables: Optional[List[Text]]

"""Create a new SQL tool."""
from aixplain.factories import AgentFactory

return AgentFactory.create_sql_tool(description=description, database=database, schema=schema, table=table)
23 changes: 23 additions & 0 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,26 @@ def test_specific_model_parameters_e2e(tool_config):
tool_used = True
break
assert tool_used, "Tool was not used in execution"


@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent])
def test_sql_tool(delete_agents_and_team_agents, AgentFactory):
assert delete_agents_and_team_agents
tool = AgentFactory.create_sql_tool(
description="Execute an SQL query and return the result",
database="https://aixplain-platform-assets.s3.us-east-1.amazonaws.com/samples/tests/test.db",
schema="employees (id INT PRIMARY KEY, name VARCHAR(100), age INT, salary DECIMAL(10, 2), department VARCHAR(100))",
)
assert tool is not None
assert tool.description == "Execute an SQL query and return the result"
agent = AgentFactory.create(
name="Teste",
description="You are a test agent that search for employee information in a database",
tools=[tool],
)
assert agent is not None
response = agent.run("What is the name of the employee with the highest salary?")
assert response is not None
assert response["completed"] is True
assert response["status"].lower() == "success"
assert "eve" in str(response["data"]["output"]).lower()
24 changes: 24 additions & 0 deletions tests/unit/agent/sql_tool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from aixplain.factories import AgentFactory
from aixplain.modules.agent.tool.sql_tool import SQLTool


def test_create_sql_tool(mocker):
tool = AgentFactory.create_sql_tool(description="Test", database="test.db", schema="test", table="test")
assert isinstance(tool, SQLTool)
assert tool.description == "Test"
assert tool.database == "test.db"
assert tool.schema == "test"
assert tool.table == "test"

tool_dict = tool.to_dict()
assert tool_dict["description"] == "Test"
assert tool_dict["parameters"] == [
{"name": "database", "value": "test.db"},
{"name": "schema", "value": "test"},
{"name": "table", "value": "test"},
]

mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://test.db")
mocker.patch("os.path.exists", return_value=True)
tool.validate()
assert tool.database == "s3://test.db"
4 changes: 2 additions & 2 deletions tests/unit/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,11 @@ def test_model_to_dict():
def test_model_repr():
# Test with supplier as dict
model1 = Model(id="test-id", name="Test Model", supplier={"name": "aiXplain"})
assert repr(model1) == "<Model: Test Model by aiXplain>"
assert repr(model1) == "<Model: Test Model by aixplain>"

# Test with supplier as string
model2 = Model(id="test-id", name="Test Model", supplier="aiXplain")
assert str(model2) == "<Model: Test Model by aiXplain>"
assert str(model2) == "<Model: Test Model by aixplain>"


def test_poll_with_error():
Expand Down