diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 581c7e88..5ff9ff69 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -92,6 +92,8 @@ def __init__( super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info self.tools = tools + for i, _ in enumerate(tools): + self.tools[i].api_key = api_key self.llm_id = llm_id if isinstance(status, str): try: @@ -110,7 +112,7 @@ def validate(self) -> None: ), "Agent Creation Error: Agent name must not contain special characters." try: - llm = ModelFactory.get(self.llm_id) + llm = ModelFactory.get(self.llm_id, api_key=self.api_key) assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") @@ -307,19 +309,19 @@ def delete(self) -> None: message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." logging.error(message) raise Exception(f"{message}") - + def update(self) -> None: """Update agent.""" import warnings import inspect + # Get the current call stack stack = inspect.stack() - if len(stack) > 2 and stack[1].function != 'save': + if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) from aixplain.factories.agent_factory.utils import build_agent @@ -343,10 +345,9 @@ def update(self) -> None: error_msg = f"Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - def save(self) -> None: """Save the Agent.""" - self.update() + self.update() def deploy(self) -> None: assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index 01b44dfa..aefa093a 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -22,6 +22,7 @@ """ from abc import ABC from typing import Optional, Text +from aixplain.utils import config class Tool(ABC): @@ -38,6 +39,7 @@ def __init__( name: Text, description: Text, version: Optional[Text] = None, + api_key: Optional[Text] = config.TEAM_API_KEY, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -46,10 +48,12 @@ def __init__( name (Text): name of the tool description (Text): descriptiion of the tool version (Text): version of the tool + api_key (Text): api key of the tool. Defaults to config.TEAM_API_KEY. """ self.name = name self.description = description self.version = version + self.api_key = api_key self.additional_info = additional_info def to_dict(self): diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 0b1c3179..bdbe0f5f 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -108,7 +108,7 @@ def validate(self) -> Model: try: model = None if self.model is not None: - model = ModelFactory.get(self.model) + model = ModelFactory.get(self.model, api_key=self.api_key) return model except Exception: raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 9ea7a5fb..ab3b4311 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -62,6 +62,6 @@ def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory try: - PipelineFactory.get(self.pipeline) + PipelineFactory.get(self.pipeline, api_key=self.api_key) except Exception: raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.") diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 6c17a5b6..10997a75 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -9,7 +9,6 @@ from aixplain.modules.agent.utils import process_variables from urllib.parse import urljoin from unittest.mock import patch -import warnings from aixplain.enums.function import Function @@ -198,6 +197,8 @@ def test_to_dict(): description="Test Agent Description", llm_id="6646261c6eb563165658bbb1", tools=[AgentFactory.create_model_tool(function="text-generation")], + api_key="test_api_key", + status=AssetStatus.DRAFT, ) agent_json = agent.to_dict() @@ -207,6 +208,7 @@ def test_to_dict(): assert agent_json["llmId"] == "6646261c6eb563165658bbb1" assert agent_json["assets"][0]["function"] == "text-generation" assert agent_json["assets"][0]["type"] == "model" + assert agent_json["status"] == "draft" def test_update_success(): @@ -256,7 +258,10 @@ def test_update_success(): mock.get(url, headers=headers, json=model_ref_response) # Capture warnings - with pytest.warns(DeprecationWarning, match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead."): + with pytest.warns( + DeprecationWarning, + match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.", + ): agent.update() assert agent.id == ref_response["id"] @@ -265,6 +270,7 @@ def test_update_success(): assert agent.llm_id == ref_response["llmId"] assert agent.tools[0].function.value == ref_response["assets"][0]["function"] + def test_save_success(): agent = Agent( id="123", @@ -310,8 +316,9 @@ def test_save_success(): "pricing": {"currency": "USD", "value": 0.0}, } mock.get(url, headers=headers, json=model_ref_response) - + import warnings + # Capture warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Trigger all warnings @@ -328,6 +335,7 @@ def test_save_success(): assert agent.llm_id == ref_response["llmId"] assert agent.tools[0].function.value == ref_response["assets"][0]["function"] + def test_run_success(): agent = Agent("123", "Test Agent", "Sample Description") url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") @@ -369,3 +377,64 @@ def test_fail_utilities_without_model(): with pytest.raises(Exception) as exc_info: AgentFactory.create(name="Test", tools=[ModelTool(function=Function.UTILITIES)], llm_id="6646261c6eb563165658bbb1") assert str(exc_info.value) == "Agent Creation Error: Utility function must be used with an associated model." + + +def test_agent_api_key_propagation(): + """Test that the api_key is properly propagated to tools when creating an agent""" + custom_api_key = "custom_test_key" + tool = AgentFactory.create_model_tool(function="text-generation") + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool], api_key=custom_api_key) + + # Check that the agent has the correct api_key + assert agent.api_key == custom_api_key + # Check that the tool received the agent's api_key + assert agent.tools[0].api_key == custom_api_key + + +def test_agent_default_api_key(): + """Test that the default api_key is used when none is provided""" + tool = AgentFactory.create_model_tool(function="text-generation") + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool]) + + # Check that the agent has the default api_key + assert agent.api_key == config.TEAM_API_KEY + # Check that the tool has the default api_key + assert agent.tools[0].api_key == config.TEAM_API_KEY + + +def test_agent_multiple_tools_api_key(): + """Test that api_key is properly propagated to multiple tools""" + custom_api_key = "custom_test_key" + tools = [ + AgentFactory.create_model_tool(function="text-generation"), + AgentFactory.create_python_interpreter_tool(), + AgentFactory.create_custom_python_code_tool( + code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool" + ), + ] + + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=tools, api_key=custom_api_key) + + # Check that all tools received the agent's api_key + for tool in agent.tools: + assert tool.api_key == custom_api_key + + +def test_agent_api_key_in_requests(): + """Test that the api_key is properly used in API requests""" + custom_api_key = "custom_test_key" + agent = Agent(id="123", name="Test Agent", description="Test Description", api_key=custom_api_key) + + with requests_mock.Mocker() as mock: + url = agent.url + # The custom api_key should be used in the headers + headers = {"x-api-key": custom_api_key, "Content-Type": "application/json"} + ref_response = {"data": "test_url", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + response = agent.run_async(data={"query": "Test query"}) + + # Verify that the request was made with the correct api_key + assert mock.last_request.headers["x-api-key"] == custom_api_key + assert response["status"] == "IN_PROGRESS" + assert response["url"] == "test_url"