Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

BUG-329: Fix agent validation when using api key #361

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __init__(
super().__init__(id, name, description, api_key, supplier, version, cost=cost)
self.additional_info = additional_info
self.tools = tools
for i, _ in enumerate(tools):
self.tools[i].api_key = api_key
self.llm_id = llm_id
if isinstance(status, str):
try:
Expand All @@ -110,7 +112,7 @@ def validate(self) -> None:
), "Agent Creation Error: Agent name must not contain special characters."

try:
llm = ModelFactory.get(self.llm_id)
llm = ModelFactory.get(self.llm_id, api_key=self.api_key)
assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model."
except Exception:
raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.")
Expand Down Expand Up @@ -307,19 +309,19 @@ def delete(self) -> None:
message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent."
logging.error(message)
raise Exception(f"{message}")

def update(self) -> None:
"""Update agent."""
import warnings
import inspect

# Get the current call stack
stack = inspect.stack()
if len(stack) > 2 and stack[1].function != 'save':
if len(stack) > 2 and stack[1].function != "save":
warnings.warn(
"update() is deprecated and will be removed in a future version. "
"Please use save() instead.",
"update() is deprecated and will be removed in a future version. " "Please use save() instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
from aixplain.factories.agent_factory.utils import build_agent

Expand All @@ -343,10 +345,9 @@ def update(self) -> None:
error_msg = f"Agent Update Error (HTTP {r.status_code}): {resp}"
raise Exception(error_msg)


def save(self) -> None:
"""Save the Agent."""
self.update()
self.update()

def deploy(self) -> None:
assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed."
Expand Down
4 changes: 4 additions & 0 deletions aixplain/modules/agent/tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
from abc import ABC
from typing import Optional, Text
from aixplain.utils import config


class Tool(ABC):
Expand All @@ -38,6 +39,7 @@ def __init__(
name: Text,
description: Text,
version: Optional[Text] = None,
api_key: Optional[Text] = config.TEAM_API_KEY,
**additional_info,
) -> None:
"""Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands.
Expand All @@ -46,10 +48,12 @@ def __init__(
name (Text): name of the tool
description (Text): descriptiion of the tool
version (Text): version of the tool
api_key (Text): api key of the tool. Defaults to config.TEAM_API_KEY.
"""
self.name = name
self.description = description
self.version = version
self.api_key = api_key
self.additional_info = additional_info

def to_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion aixplain/modules/agent/tool/model_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def validate(self) -> Model:
try:
model = None
if self.model is not None:
model = ModelFactory.get(self.model)
model = ModelFactory.get(self.model, api_key=self.api_key)
return model
except Exception:
raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.")
2 changes: 1 addition & 1 deletion aixplain/modules/agent/tool/pipeline_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ def validate(self):
from aixplain.factories.pipeline_factory import PipelineFactory

try:
PipelineFactory.get(self.pipeline)
PipelineFactory.get(self.pipeline, api_key=self.api_key)
except Exception:
raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.")
75 changes: 72 additions & 3 deletions tests/unit/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from aixplain.modules.agent.utils import process_variables
from urllib.parse import urljoin
from unittest.mock import patch
import warnings
from aixplain.enums.function import Function


Expand Down Expand Up @@ -198,6 +197,8 @@ def test_to_dict():
description="Test Agent Description",
llm_id="6646261c6eb563165658bbb1",
tools=[AgentFactory.create_model_tool(function="text-generation")],
api_key="test_api_key",
status=AssetStatus.DRAFT,
)

agent_json = agent.to_dict()
Expand All @@ -207,6 +208,7 @@ def test_to_dict():
assert agent_json["llmId"] == "6646261c6eb563165658bbb1"
assert agent_json["assets"][0]["function"] == "text-generation"
assert agent_json["assets"][0]["type"] == "model"
assert agent_json["status"] == "draft"


def test_update_success():
Expand Down Expand Up @@ -256,7 +258,10 @@ def test_update_success():
mock.get(url, headers=headers, json=model_ref_response)

# Capture warnings
with pytest.warns(DeprecationWarning, match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead."):
with pytest.warns(
DeprecationWarning,
match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.",
):
agent.update()

assert agent.id == ref_response["id"]
Expand All @@ -265,6 +270,7 @@ def test_update_success():
assert agent.llm_id == ref_response["llmId"]
assert agent.tools[0].function.value == ref_response["assets"][0]["function"]


def test_save_success():
agent = Agent(
id="123",
Expand Down Expand Up @@ -310,8 +316,9 @@ def test_save_success():
"#": {"currency": "USD", "value": 0.0},
}
mock.get(url, headers=headers, json=model_ref_response)

import warnings

# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Trigger all warnings
Expand All @@ -328,6 +335,7 @@ def test_save_success():
assert agent.llm_id == ref_response["llmId"]
assert agent.tools[0].function.value == ref_response["assets"][0]["function"]


def test_run_success():
agent = Agent("123", "Test Agent", "Sample Description")
url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run")
Expand Down Expand Up @@ -369,3 +377,64 @@ def test_fail_utilities_without_model():
with pytest.raises(Exception) as exc_info:
AgentFactory.create(name="Test", tools=[ModelTool(function=Function.UTILITIES)], llm_id="6646261c6eb563165658bbb1")
assert str(exc_info.value) == "Agent Creation Error: Utility function must be used with an associated model."


def test_agent_api_key_propagation():
"""Test that the api_key is properly propagated to tools when creating an agent"""
custom_api_key = "custom_test_key"
tool = AgentFactory.create_model_tool(function="text-generation")
agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool], api_key=custom_api_key)

# Check that the agent has the correct api_key
assert agent.api_key == custom_api_key
# Check that the tool received the agent's api_key
assert agent.tools[0].api_key == custom_api_key


def test_agent_default_api_key():
"""Test that the default api_key is used when none is provided"""
tool = AgentFactory.create_model_tool(function="text-generation")
agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool])

# Check that the agent has the default api_key
assert agent.api_key == config.TEAM_API_KEY
# Check that the tool has the default api_key
assert agent.tools[0].api_key == config.TEAM_API_KEY


def test_agent_multiple_tools_api_key():
"""Test that api_key is properly propagated to multiple tools"""
custom_api_key = "custom_test_key"
tools = [
AgentFactory.create_model_tool(function="text-generation"),
AgentFactory.create_python_interpreter_tool(),
AgentFactory.create_custom_python_code_tool(
code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool"
),
]

agent = Agent(id="123", name="Test Agent", description="Test Description", tools=tools, api_key=custom_api_key)

# Check that all tools received the agent's api_key
for tool in agent.tools:
assert tool.api_key == custom_api_key


def test_agent_api_key_in_requests():
"""Test that the api_key is properly used in API requests"""
custom_api_key = "custom_test_key"
agent = Agent(id="123", name="Test Agent", description="Test Description", api_key=custom_api_key)

with requests_mock.Mocker() as mock:
url = agent.url
# The custom api_key should be used in the headers
headers = {"x-api-key": custom_api_key, "Content-Type": "application/json"}
ref_response = {"data": "test_url", "status": "IN_PROGRESS"}
mock.post(url, headers=headers, json=ref_response)

response = agent.run_async(data={"query": "Test query"})

# Verify that the request was made with the correct api_key
assert mock.last_request.headers["x-api-key"] == custom_api_key
assert response["status"] == "IN_PROGRESS"
assert response["url"] == "test_url"