Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

BUG-382: Fixing validation of team and agents #406

Merged
merged 8 commits into from
Feb 24, 2025
Merged
2 changes: 1 addition & 1 deletion aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def create(
"tasks": [task.to_dict() for task in tasks],
}
agent = build_agent(payload=payload, api_key=api_key)
agent.validate()
agent.validate(raise_exception=True)
response = "Unspecified error"
try:
logging.debug(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}")
Expand Down
95 changes: 58 additions & 37 deletions aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__author__ = "thiagocastroferreira"

import logging
import aixplain.utils.config as config
from aixplain.enums import Function, Supplier
from aixplain.enums.asset_status import AssetStatus
Expand All @@ -16,49 +17,69 @@
GPT_4o_ID = "6646261c6eb563165658bbb1"


def build_tool(tool: Dict):
"""Build a tool from a dictionary.

Args:
tool (Dict): Tool dictionary.

Returns:
Tool: Tool object.
"""
if tool["type"] == "model":
supplier = "aixplain"
for supplier_ in Supplier:
if isinstance(tool["supplier"], str):
if tool["supplier"] is not None and tool["supplier"].lower() in [
supplier_.value["code"].lower(),
supplier_.value["name"].lower(),
]:
supplier = supplier_
break
tool = ModelTool(
function=Function(tool.get("function", None)),
supplier=supplier,
version=tool["version"],
model=tool["assetId"],
description=tool.get("description", ""),
parameters=tool.get("parameters", None),
)
elif tool["type"] == "pipeline":
tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"])
elif tool["type"] == "utility":
if tool.get("utilityCode", None) is not None:
tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"])
else:
tool = PythonInterpreterTool()
elif tool["type"] == "sql":
parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])}
database = parameters.get("database")
schema = parameters.get("schema")
tables = parameters.get("tables", None)
tables = tables.split(",") if tables is not None else None
enable_commit = parameters.get("enable_commit", False)
tool = SQLTool(
description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit
)
else:
raise Exception("Agent Creation Error: Tool type not supported.")

return tool


def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent:
"""Instantiate a new agent in the platform."""
tools_dict = payload["assets"]
tools = []
for tool in tools_dict:
if tool["type"] == "model":
supplier = "aixplain"
for supplier_ in Supplier:
if isinstance(tool["supplier"], str):
if tool["supplier"] is not None and tool["supplier"].lower() in [
supplier_.value["code"].lower(),
supplier_.value["name"].lower(),
]:
supplier = supplier_
break
tool = ModelTool(
function=Function(tool.get("function", None)),
supplier=supplier,
version=tool["version"],
model=tool["assetId"],
description=tool.get("description", ""),
parameters=tool.get("parameters", None),
try:
tools.append(build_tool(tool))
except Exception as e:
logging.warning(
f"Tool {tool['assetId']} is not available. Make sure it exists or you have access to it. "
"If you think this is an error, please contact the administrators."
)
elif tool["type"] == "pipeline":
tool = PipelineTool(description=tool["description"], pipeline=tool["assetId"])
elif tool["type"] == "utility":
if tool.get("utilityCode", None) is not None:
tool = CustomPythonCodeTool(description=tool["description"], code=tool["utilityCode"])
else:
tool = PythonInterpreterTool()
elif tool["type"] == "sql":
parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])}
database = parameters.get("database")
schema = parameters.get("schema")
tables = parameters.get("tables", None)
tables = tables.split(",") if tables is not None else None
enable_commit = parameters.get("enable_commit", False)
tool = SQLTool(
description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit
)
else:
raise Exception("Agent Creation Error: Tool type not supported.")
tools.append(tool)
continue

agent = Agent(
id=payload["id"] if "id" in payload else "",
Expand Down
2 changes: 1 addition & 1 deletion aixplain/factories/team_agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def create(
}

team_agent = build_team_agent(payload=payload, api_key=api_key)
team_agent.validate()
team_agent.validate(raise_exception=True)
response = "Unspecified error"
try:
logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}")
Expand Down
11 changes: 9 additions & 2 deletions aixplain/factories/team_agent_factory/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__author__ = "lucaspavanelli"

import logging
import aixplain.utils.config as config
from aixplain.enums.asset_status import AssetStatus
from aixplain.modules.team_agent import TeamAgent
Expand All @@ -16,8 +17,14 @@ def build_team_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Team
agents_dict = payload["agents"]
agents = []
for i, agent in enumerate(agents_dict):
agent = AgentFactory.get(agent["assetId"])
agents.append(agent)
try:
agents.append(AgentFactory.get(agent["assetId"]))
except Exception:
logging.warning(
f"Agent {agent['assetId']} not found. Make sure it exists or you have access to it. "
"If you think this is an error, please contact the administrators."
)
continue

team_agent = TeamAgent(
id=payload.get("id", ""),
Expand Down
91 changes: 75 additions & 16 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class Agent(Model):
cost (Dict, optional): model price. Defaults to None.
"""

is_valid: bool

def __init__(
self,
id: Text,
Expand Down Expand Up @@ -107,8 +109,9 @@ def __init__(
status = AssetStatus.DRAFT
self.status = status
self.tasks = tasks
self.is_valid = True

def validate(self) -> None:
def _validate(self) -> None:
"""Validate the Agent."""
from aixplain.factories.model_factory import ModelFactory

Expand All @@ -119,15 +122,36 @@ def validate(self) -> None:

try:
llm = ModelFactory.get(self.llm_id, api_key=self.api_key)
assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model."
except Exception:
raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.")

assert (
llm.function == Function.TEXT_GENERATION
), "Large Language Model must be a text generation model."

for tool in self.tools:
if isinstance(tool, Tool):
tool.validate()
elif isinstance(tool, Model):
assert not isinstance(tool, Agent), "Agent cannot contain another Agent."
assert not isinstance(
tool, Agent
), "Agent cannot contain another Agent."

def validate(self, raise_exception: bool = False) -> bool:
"""Validate the Agent."""
try:
self._validate()
self.is_valid = True
except Exception as e:
self.is_valid = False
if raise_exception:
raise e
else:
logging.warning(f"Agent Validation Error: {e}")
logging.warning(
"You won't be able to run the Agent until the issues are handled manually."
)
return self.is_valid

def run(
self,
Expand Down Expand Up @@ -183,7 +207,9 @@ def run(
return response
poll_url = response["url"]
end = time.time()
result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
result = self.sync_poll(
poll_url, name=name, timeout=timeout, wait_time=wait_time
)
result_data = result.data
return AgentResponse(
status=ResponseStatus.SUCCESS,
Expand Down Expand Up @@ -245,10 +271,19 @@ def run_async(
"""
from aixplain.factories.file_factory import FileFactory

assert data is not None or query is not None, "Either 'data' or 'query' must be provided."
if not self.is_valid:
raise Exception(
"Agent is not valid. Please validate the agent before running."
)

assert (
data is not None or query is not None
), "Either 'data' or 'query' must be provided."
if data is not None:
if isinstance(data, dict):
assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided."
assert (
"query" in data and data["query"] is not None
), "When providing a dictionary, 'query' must be provided."
query = data.get("query")
if session_id is None:
session_id = data.get("session_id")
Expand All @@ -261,7 +296,9 @@ def run_async(

# process content inputs
if content is not None:
assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text."
assert (
FileFactory.check_storage_type(query) == StorageType.TEXT
), "When providing 'content', query must be text."

if isinstance(content, list):
assert len(content) <= 3, "The maximum number of content inputs is 3."
Expand All @@ -270,7 +307,9 @@ def run_async(
query += f"\n{input_link}"
elif isinstance(content, dict):
for key, value in content.items():
assert "{{" + key + "}}" in query, f"Key '{key}' not found in query."
assert (
"{{" + key + "}}" in query
), f"Key '{key}' not found in query."
value = FileFactory.to_link(value)
query = query.replace("{{" + key + "}}", f"'{value}'")

Expand All @@ -285,8 +324,16 @@ def run_async(
"sessionId": session_id,
"history": history,
"executionParams": {
"maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
"maxTokens": (
parameters["max_tokens"]
if "max_tokens" in parameters
else max_tokens
),
"maxIterations": (
parameters["max_iterations"]
if "max_iterations" in parameters
else max_iterations
),
"outputFormat": output_format.value,
},
}
Expand Down Expand Up @@ -320,7 +367,11 @@ def to_dict(self) -> Dict:
"assets": [tool.to_dict() for tool in self.tools],
"description": self.description,
"role": self.instructions,
"supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier,
"supplier": (
self.supplier.value["code"]
if isinstance(self.supplier, Supplier)
else self.supplier
),
"version": self.version,
"llmId": self.llm_id,
"status": self.status.value,
Expand All @@ -331,7 +382,10 @@ def delete(self) -> None:
"""Delete Agent service"""
try:
url = urljoin(config.BACKEND_URL, f"sdk/agents/{self.id}")
headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
headers = {
"x-api-key": config.TEAM_API_KEY,
"Content-Type": "application/json",
}
logging.debug(f"Start service for DELETE Agent - {url} - {headers}")
r = _request_with_retry("delete", url, headers=headers)
logging.debug(f"Result of request for DELETE Agent - {r.status_code}")
Expand All @@ -355,19 +409,22 @@ def update(self) -> None:
stack = inspect.stack()
if len(stack) > 2 and stack[1].function != "save":
warnings.warn(
"update() is deprecated and will be removed in a future version. " "Please use save() instead.",
"update() is deprecated and will be removed in a future version. "
"Please use save() instead.",
DeprecationWarning,
stacklevel=2,
)
from aixplain.factories.agent_factory.utils import build_agent

self.validate()
self.validate(raise_exception=True)
url = urljoin(config.BACKEND_URL, f"sdk/agents/{self.id}")
headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}

payload = self.to_dict()

logging.debug(f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}")
logging.debug(
f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}"
)
resp = "No specified error."
try:
r = _request_with_retry("put", url, headers=headers, json=payload)
Expand All @@ -386,7 +443,9 @@ def save(self) -> None:
self.update()

def deploy(self) -> None:
assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed."
assert (
self.status == AssetStatus.DRAFT
), "Agent must be in draft status to be deployed."
assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed."
self.status = AssetStatus.ONBOARDED
self.update()
Expand Down
Loading