From 292f441ea063dd9fd370e82020bec0dede93e5b4 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 6 Feb 2025 00:38:05 +0300 Subject: [PATCH 1/6] caching for agents, pipelines and models --- aixplain/enums/function.py | 11 ++- aixplain/enums/language.py | 5 +- aixplain/enums/license.py | 6 +- aixplain/modules/agent/__init__.py | 17 ++++ aixplain/modules/agent/cache_agents.py | 101 ++++++++++++++++++++ aixplain/modules/model/__init__.py | 17 +++- aixplain/modules/model/cache_models.py | 86 +++++++++++++++++ aixplain/modules/pipeline/asset.py | 10 ++ aixplain/modules/pipeline/pipeline_cache.py | 81 ++++++++++++++++ aixplain/utils/cache_utils.py | 35 ++++--- 10 files changed, 347 insertions(+), 22 deletions(-) create mode 100644 aixplain/modules/agent/cache_agents.py create mode 100644 aixplain/modules/model/cache_models.py create mode 100644 aixplain/modules/pipeline/pipeline_cache.py diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a51f5301..f8190c9d 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -28,15 +28,18 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter +import os CACHE_FILE = f"{CACHE_FOLDER}/functions.json" - +LOCK_FILE = f"{CACHE_FILE}.lock" def load_functions(): api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL - resp = load_from_cache(CACHE_FILE) + os.makedirs(CACHE_FOLDER, exist_ok=True) + + resp = load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: url = urljoin(backend_url, "sdk/functions") @@ -47,7 +50,7 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): @@ -63,6 +66,8 @@ def get_input_output_params(self) -> Tuple[Dict, Dict]: Tuple[Dict, Dict]: A tuple containing (input_params, output_params) """ function_io = FunctionInputOutput.get(self.value, None) + if function_io is None: + return {}, {} input_params = {param["code"]: param for param in function_io["spec"]["params"]} output_params = {param["code"]: param for param in function_io["spec"]["output"]} return input_params, output_params diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index db66b2a1..52938793 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -28,10 +28,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE) + resp = load_from_cache(CACHE_FILE,LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -45,7 +46,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. 
For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index a860a539..22ad9b0d 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -29,10 +29,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = load_from_cache(CACHE_FILE) + resp = load_from_cache(CACHE_FILE, LOCK_FILE) try: if resp is None: @@ -48,7 +49,8 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) + licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) except Exception: diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index a7707d05..9350a638 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -39,6 +39,7 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables +from aixplain.modules.agent.cache_agents import load_agents from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -77,6 +78,22 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: + AgentCache, AgentDetails = load_agents(cache_expiry=86400) + + if id in AgentDetails: + cached_agent= AgentDetails[id] + name=cached_agent["name"] + description=cached_agent["description"] + tools=cached_agent["tools"] + llm_id=cached_agent["llm_id"] + api_key=cached_agent["api_key"] + supplier=cached_agent["supplier"] + version=cached_agent["version"] + cost=cached_agent["cost"] + status=cached_agent["status"] + tasks=cached_agent["tasks"] + additional_info=cached_agent["additional_info"] + """Create an Agent with the necessary information. Args: diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py new file mode 100644 index 00000000..a04bad00 --- /dev/null +++ b/aixplain/modules/agent/cache_agents.py @@ -0,0 +1,101 @@ +import os +import json +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, List, Tuple, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier +from aixplain.modules.agent.tool import Tool + +AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" +LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" + +def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: + """ + Load AI agents from cache or fetch from backend if not cached. + Only agents with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. 
+ + Returns: + Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) + """ + if cache_expiry is None: + cache_expiry = 86400 + + os.makedirs(CACHE_FOLDER, exist_ok=True) + + cached_data = load_from_cache(AGENT_CACHE_FILE, LOCK_FILE) + + if cached_data is not None: + return parse_agents(cached_data) + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + url = urljoin(backend_url, "sdk/agents") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + try: + response = _request_with_retry("get", url, headers=headers) + response.raise_for_status() + agents_data = response.json() + except Exception as e: + logging.error(f"Failed to fetch agents from API: {e}") + return Enum("Agent", {}), {} + + + onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] + + save_to_cache(AGENT_CACHE_FILE, {"items": onboarded_agents}, LOCK_FILE) + + return parse_agents({"items": onboarded_agents}) + + +def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: + """ + Convert agent data into an Enum and dictionary format for easy use. + + Args: + agents_data (Dict): JSON response with agents list. + + Returns: + - agents_enum: Enum with agent IDs. + - agents_details: Dictionary containing all agent parameters. + """ + if not agents_data["items"]: + logging.warning("No onboarded agents found.") + return Enum("Agent", {}), {} + + agents_enum = Enum( + "Agent", + {a["id"].upper().replace("-", "_"): a["id"] for a in agents_data["items"]}, + type=str, + ) + + agents_details = { + agent["id"]: { + "id": agent["id"], + "name": agent.get("name", ""), + "description": agent.get("description", ""), + "role": agent.get("role", ""), + "tools": [Tool(t) if isinstance(t, dict) else t for t in agent.get("tools", [])], + "llm_id": agent.get("llm_id", "6646261c6eb563165658bbb1"), + "supplier": agent.get("supplier", "aiXplain"), + "version": agent.get("version", "1.0"), + "status": agent.get("status", "onboarded"), + "created_at": agent.get("created_at", ""), + "tasks": agent.get("tasks", []), + **agent, + } + for agent in agents_data["items"] + } + + return agents_enum, agents_details + + +Agent, AgentDetails = load_agents() diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 5788daab..69e02cc6 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,7 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters - +from aixplain.modules.model.cache_models import load_models class Model(Asset): """This is ready-to-use AI model. This model can be run in both synchronous and asynchronous manner. @@ -91,6 +91,21 @@ def __init__( model_params (Dict, optional): parameters for the function. 
**additional_info: Any additional Model info to be saved """ + ModelCache, ModelDetails = load_models(cache_expiry=86400) + + if id in ModelDetails: + + cached_model = ModelDetails[id] + input_params = cached_model["input"] + api_key = cached_model["spec"].get("api_key", api_key) + additional_info = cached_model["spec"].get("additional_info", {}) + function = cached_model["spec"].get("function", function) + is_subscribed = cached_model["spec"].get("is_subscribed", is_subscribed) + created_at = cached_model["spec"].get("created_at", created_at) + model_params = cached_model["spec"].get("model_params", model_params) + output_params = cached_model["output"] + description = cached_model["spec"].get("description", description) + super().__init__(id, name, description, supplier, version, cost=cost) self.api_key = api_key self.additional_info = additional_info diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py new file mode 100644 index 00000000..d5628587 --- /dev/null +++ b/aixplain/modules/model/cache_models.py @@ -0,0 +1,86 @@ +import os +import json +import time +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier, Function + +CACHE_FILE = f"{CACHE_FOLDER}/models.json" +LOCK_FILE = f"{CACHE_FILE}.lock" + +def load_models(cache_expiry: Optional[int] = None): + """ + Load models from cache or fetch from backend if not cached. + Only models with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. + """ + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + + cached_data = load_from_cache(CACHE_FILE, LOCK_FILE) + + if cached_data is not None: + + return parse_models(cached_data) + + url = urljoin(backend_url, "sdk/models") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + response = _request_with_retry("get", url, headers=headers) + if not 200 <= response.status_code < 300: + raise Exception(f"Models could not be loaded, API key '{api_key}' might be invalid.") + + models_data = response.json() + + onboarded_models = [model for model in models_data["items"] if model["status"].lower() == "onboarded"] + save_to_cache(CACHE_FILE, {"items": onboarded_models}, LOCK_FILE) + + return parse_models({"items": onboarded_models}) + +def parse_models(models_data): + """ + Convert model data into an Enum and dictionary format for easy use. + + Returns: + - models_enum: Enum with model IDs. + - models_details: Dictionary containing all model parameters. 
+ """ + + if not models_data["items"]: + logging.warning("No onboarded models found.") + return Enum("Model", {}), {} + models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) + + models_details = { + model["id"]: { + "id": model["id"], + "name": model.get("name", ""), + "description": model.get("description", ""), + "api_key": model.get("api_key", config.TEAM_API_KEY), + "supplier": model.get("supplier", "aiXplain"), + "version": model.get("version"), + "function": model.get("function"), + "is_subscribed": model.get("is_subscribed", False), + "cost": model.get("cost"), + "created_at": model.get("created_at"), + "input_params": model.get("input_params"), + "output_params": model.get("output_params"), + "model_params": model.get("model_params"), + **model, + } + for model in models_data["items"] + } + + return models_enum, models_details + + +Model, ModelDetails = load_models() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 88364873..0cf5ba63 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -31,6 +31,7 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin +from aixplain.modules.pipeline.pipeline_cache import PipelineDetails class Pipeline(Asset): @@ -72,6 +73,15 @@ def __init__( status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ + if id in PipelineDetails: + cached_pipeline = PipelineDetails[id] + name = cached_pipeline["name"] + api_key = cached_pipeline["api_key"] + supplier = cached_pipeline["supplier"] + version = cached_pipeline["version"] + status = cached_pipeline["status"] + additional_info = cached_pipeline["architecture"] + if not name: raise ValueError("Pipeline name is required") diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py new file mode 100644 index 00000000..4c2f817e --- /dev/null +++ b/aixplain/modules/pipeline/pipeline_cache.py @@ -0,0 +1,81 @@ +import os +import json +import time +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier + +PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" +LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" +def load_pipelines(cache_expiry: Optional[int] = None): + """ + Load pipelines from cache or fetch from backend if not cached. + Only pipelines with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. 
+ """ + if cache_expiry is None: + cache_expiry = 86400 + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + + cached_data = load_from_cache(PIPELINE_CACHE_FILE, LOCK_FILE) + if cached_data is not None: + return parse_pipelines(cached_data) + + url = urljoin(backend_url, "sdk/pipelines") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + response = _request_with_retry("get", url, headers=headers) + if not 200 <= response.status_code < 300: + raise Exception(f"Pipelines could not be loaded, API key '{api_key}' might be invalid.") + + pipelines_data = response.json() + + onboarded_pipelines = [pipeline for pipeline in pipelines_data["items"] if pipeline["status"].lower() == "onboarded"] + + save_to_cache(PIPELINE_CACHE_FILE, {"items": onboarded_pipelines}, LOCK_FILE) + + return parse_pipelines({"items": onboarded_pipelines}) + +def parse_pipelines(pipelines_data): + """ + Convert pipeline data into an Enum and dictionary format for easy use. + + Returns: + - pipelines_enum: Enum with pipeline IDs. + - pipelines_details: Dictionary containing all pipeline parameters. + """ + if not pipelines_data["items"]: + logging.warning("No onboarded pipelines found.") + return Enum("Pipeline", {}), {} + + pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) + + pipelines_details = { + pipeline["id"]: { + "id": pipeline["id"], + "name": pipeline.get("name", ""), + "description": pipeline.get("description", ""), + "api_key": pipeline.get("api_key", config.TEAM_API_KEY), + "supplier": pipeline.get("supplier", "aiXplain"), + "version": pipeline.get("version", "1.0"), + "status": pipeline.get("status", "onboarded"), + "created_at": pipeline.get("created_at"), + "architecture": pipeline.get("architecture", {}), + **pipeline, + } + for pipeline in pipelines_data["items"] + } + + return pipelines_enum, pipelines_details + +Pipeline, PipelineDetails = load_pipelines() diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 5a0eb6ae..6717234f 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -2,26 +2,33 @@ import json import time import logging +from filelock import FileLock -CACHE_DURATION = 24 * 60 * 60 -CACHE_FOLDER = ".aixplain_cache" +CACHE_FOLDER = ".cache" +CACHE_FILE = f"{CACHE_FOLDER}/cache.json" +LOCK_FILE = f"{CACHE_FILE}.lock" +DEFAULT_CACHE_EXPIRY = 86400 +def get_cache_expiry(): + return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) -def save_to_cache(cache_file, data): + +def save_to_cache(cache_file, data, lock_file): try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) - with open(cache_file, "w") as f: - json.dump({"timestamp": time.time(), "data": data}, f) + with FileLock(lock_file): + with open(cache_file, "w") as f: + json.dump({"timestamp": time.time(), "data": data}, f) except Exception as e: logging.error(f"Failed to save cache to {cache_file}: {e}") - -def load_from_cache(cache_file): - if os.path.exists(cache_file) is True: - with open(cache_file, "r") as f: - cache_data = json.load(f) - if time.time() - cache_data["timestamp"] < CACHE_DURATION: - return cache_data["data"] - else: - return None +def load_from_cache(cache_file, lock_file): + if os.path.exists(cache_file): + with FileLock(lock_file): + with open(cache_file, "r") as f: + cache_data = json.load(f) + if time.time() - cache_data["timestamp"] < int(get_cache_expiry()): + return cache_data["data"] + else: + return None return None From 
6c8e3cf6b515f7085750d56babbdf6d6aa7cdd21 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 6 Feb 2025 00:48:08 +0300 Subject: [PATCH 2/6] formatting --- aixplain/enums/function.py | 3 +- aixplain/enums/language.py | 4 +- aixplain/enums/license.py | 2 +- .../factories/team_agent_factory/__init__.py | 2 +- aixplain/modules/agent/__init__.py | 24 ++++----- aixplain/modules/agent/cache_agents.py | 16 +++--- aixplain/modules/model/__init__.py | 1 + aixplain/modules/model/cache_models.py | 8 +-- aixplain/modules/model/utility_model.py | 16 +++--- aixplain/modules/pipeline/asset.py | 51 ++++++------------- aixplain/modules/pipeline/pipeline_cache.py | 12 +++-- aixplain/utils/cache_utils.py | 2 + aixplain/utils/file_utils.py | 4 +- .../model/run_utility_model_test.py | 27 ++++++---- tests/unit/utility_test.py | 44 ++++++++++++---- tests/unit/utility_tool_decorator_test.py | 38 +++++++------- 16 files changed, 138 insertions(+), 116 deletions(-) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index f8190c9d..a77f3cfc 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -31,7 +31,8 @@ import os CACHE_FILE = f"{CACHE_FOLDER}/functions.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" + def load_functions(): api_key = config.TEAM_API_KEY diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index 52938793..c129822f 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -28,11 +28,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE,LOCK_FILE) + resp = load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 22ad9b0d..f9758b84 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -29,7 +29,7 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index ea145d9a..c9b7e6cc 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -62,7 +62,7 @@ def create( from aixplain.modules.agent import Agent assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class" - + mentalist_and_inspector_llm_id = None if use_inspector or use_mentalist_and_inspector: mentalist_and_inspector_llm_id = llm_id diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 9350a638..441c20fa 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -81,18 +81,18 @@ def __init__( AgentCache, AgentDetails = load_agents(cache_expiry=86400) if id in AgentDetails: - cached_agent= AgentDetails[id] - name=cached_agent["name"] - description=cached_agent["description"] - tools=cached_agent["tools"] - llm_id=cached_agent["llm_id"] - api_key=cached_agent["api_key"] - supplier=cached_agent["supplier"] - version=cached_agent["version"] - cost=cached_agent["cost"] - status=cached_agent["status"] - tasks=cached_agent["tasks"] - 
additional_info=cached_agent["additional_info"] + cached_agent = AgentDetails[id] + name = cached_agent["name"] + description = cached_agent["description"] + tools = cached_agent["tools"] + llm_id = cached_agent["llm_id"] + api_key = cached_agent["api_key"] + supplier = cached_agent["supplier"] + version = cached_agent["version"] + cost = cached_agent["cost"] + status = cached_agent["status"] + tasks = cached_agent["tasks"] + additional_info = cached_agent["additional_info"] """Create an Agent with the necessary information. diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py index a04bad00..4d38344f 100644 --- a/aixplain/modules/agent/cache_agents.py +++ b/aixplain/modules/agent/cache_agents.py @@ -9,10 +9,11 @@ from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER from aixplain.enums import Supplier -from aixplain.modules.agent.tool import Tool +from aixplain.modules.agent.tool import Tool AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" -LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" +LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" + def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: """ @@ -26,7 +27,7 @@ def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) """ if cache_expiry is None: - cache_expiry = 86400 + cache_expiry = 86400 os.makedirs(CACHE_FOLDER, exist_ok=True) @@ -42,12 +43,11 @@ def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: try: response = _request_with_retry("get", url, headers=headers) - response.raise_for_status() + response.raise_for_status() agents_data = response.json() except Exception as e: logging.error(f"Failed to fetch agents from API: {e}") - return Enum("Agent", {}), {} - + return Enum("Agent", {}), {} onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] @@ -67,7 +67,7 @@ def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: - agents_enum: Enum with agent IDs. - agents_details: Dictionary containing all agent parameters. """ - if not agents_data["items"]: + if not agents_data["items"]: logging.warning("No onboarded agents found.") return Enum("Agent", {}), {} @@ -90,7 +90,7 @@ def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: "status": agent.get("status", "onboarded"), "created_at": agent.get("created_at", ""), "tasks": agent.get("tasks", []), - **agent, + **agent, } for agent in agents_data["items"] } diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 69e02cc6..f3962e01 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -36,6 +36,7 @@ from aixplain.modules.model.model_parameters import ModelParameters from aixplain.modules.model.cache_models import load_models + class Model(Asset): """This is ready-to-use AI model. This model can be run in both synchronous and asynchronous manner. 
diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py index d5628587..a336c488 100644 --- a/aixplain/modules/model/cache_models.py +++ b/aixplain/modules/model/cache_models.py @@ -12,7 +12,8 @@ from aixplain.enums import Supplier, Function CACHE_FILE = f"{CACHE_FOLDER}/models.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" + def load_models(cache_expiry: Optional[int] = None): """ @@ -46,6 +47,7 @@ def load_models(cache_expiry: Optional[int] = None): return parse_models({"items": onboarded_models}) + def parse_models(models_data): """ Convert model data into an Enum and dictionary format for easy use. @@ -55,7 +57,7 @@ def parse_models(models_data): - models_details: Dictionary containing all model parameters. """ - if not models_data["items"]: + if not models_data["items"]: logging.warning("No onboarded models found.") return Enum("Model", {}), {} models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) @@ -75,7 +77,7 @@ def parse_models(models_data): "input_params": model.get("input_params"), "output_params": model.get("output_params"), "model_params": model.get("model_params"), - **model, + **model, } for model in models_data["items"] } diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 6a323f45..a0b0ca15 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -43,17 +43,20 @@ def validate(self): def to_dict(self): return {"name": self.name, "description": self.description, "type": self.type.value} + # Tool decorator -def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status = AssetStatus.DRAFT): +def utility_tool( + name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status=AssetStatus.DRAFT +): """Decorator for utility tool functions - + Args: name: Name of the utility tool description: Description of what the utility tool does inputs: List of input parameters, must be UtilityModelInput objects output_examples: Examples of expected outputs status: Asset status - + Raises: ValueError: If name or description is empty TypeError: If inputs contains non-UtilityModelInput objects @@ -63,7 +66,7 @@ def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] raise ValueError("Utility tool name cannot be empty") if not description or not description.strip(): raise ValueError("Utility tool description cannot be empty") - + # Validate inputs if inputs is not None: if not isinstance(inputs, list): @@ -71,7 +74,7 @@ def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] for input_param in inputs: if not isinstance(input_param, UtilityModelInput): raise TypeError(f"Invalid input parameter: {input_param}. 
All inputs must be UtilityModelInput objects") - + def decorator(func): func._is_utility_tool = True # Mark function as utility tool func._tool_name = name.strip() @@ -80,6 +83,7 @@ def decorator(func): func._tool_output_examples = output_examples func._tool_status = status return func + return decorator @@ -116,7 +120,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - status: AssetStatus = AssetStatus.ONBOARDED,# TODO: change to draft when we have the backend ready + status: AssetStatus = AssetStatus.ONBOARDED, # TODO: change to draft when we have the backend ready **additional_info, ) -> None: """Utility Model Init diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 0cf5ba63..216274fc 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -81,7 +81,7 @@ def __init__( version = cached_pipeline["version"] status = cached_pipeline["status"] additional_info = cached_pipeline["architecture"] - + if not name: raise ValueError("Pipeline name is required") @@ -122,9 +122,7 @@ def __polling( while not completed and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) - logging.debug( - f"Polling for Pipeline: Status of polling for {name} : {response_body}" - ) + logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") completed = response_body["completed"] end = time.time() @@ -136,13 +134,9 @@ def __polling( logging.error(f"Polling for Pipeline: polling for {name} : Continue") if response_body and response_body["status"] == "SUCCESS": try: - logging.debug( - f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}" - ) + logging.debug(f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}") except Exception: - logging.error( - f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}" - ) + logging.error(f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}") else: logging.error( f"Polling for Pipeline: Final status of polling for {name} : No response in {timeout} seconds - {response_body}" @@ -172,9 +166,7 @@ def poll(self, poll_url: Text, name: Text = "pipeline_process") -> Dict: resp["data"] = json.loads(resp["data"])["response"] except Exception: resp = r.json() - logging.info( - f"Single Poll for Pipeline: Status of polling for {name} : {resp}" - ) + logging.info(f"Single Poll for Pipeline: Status of polling for {name} : {resp}") except Exception: resp = {"status": "FAILED"} return resp @@ -196,9 +188,7 @@ def _should_fallback_to_v2(self, response: Dict, version: str) -> bool: should_fallback = False if "status" not in response or response["status"] == "FAILED": should_fallback = True - elif response["status"] == "SUCCESS" and ( - "data" not in response or not response["data"] - ): + elif response["status"] == "SUCCESS" and ("data" not in response or not response["data"]): should_fallback = True # Check for conditions that require a fallback @@ -304,10 +294,7 @@ def __prepare_payload( try: payload = json.loads(data) if isinstance(payload, dict) is False: - if ( - isinstance(payload, int) is True - or isinstance(payload, float) is True - ): + if isinstance(payload, int) is True or isinstance(payload, float) is True: payload = str(payload) payload = {"data": payload} except Exception: @@ -345,9 +332,7 @@ def __prepare_payload( asset_payload["dataAsset"]["dataset_id"] = dasset.id 
source_data_list = [ - dfield - for dfield in dasset.source_data - if dasset.source_data[dfield].id == data[node_label] + dfield for dfield in dasset.source_data if dasset.source_data[dfield].id == data[node_label] ] if len(source_data_list) > 0: @@ -420,9 +405,7 @@ def run_async( try: if 200 <= r.status_code < 300: resp = r.json() - logging.info( - f"Result of request for {name} - {r.status_code} - {resp}" - ) + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") poll_url = resp["url"] response = {"status": "IN_PROGRESS", "url": poll_url} else: @@ -438,7 +421,9 @@ def run_async( error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." else: status_code = str(r.status_code) - error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + error = ( + f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + ) response = {"status": "FAILED", "error_message": error} logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: @@ -487,9 +472,7 @@ def update( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -507,9 +490,7 @@ def update( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("put", url, headers=headers, json=payload) response = r.json() logging.info(f"Pipeline {response['id']} Updated.") @@ -564,9 +545,7 @@ def save( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py index 4c2f817e..2d90e069 100644 --- a/aixplain/modules/pipeline/pipeline_cache.py +++ b/aixplain/modules/pipeline/pipeline_cache.py @@ -13,6 +13,8 @@ PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" + + def load_pipelines(cache_expiry: Optional[int] = None): """ Load pipelines from cache or fetch from backend if not cached. @@ -22,7 +24,7 @@ def load_pipelines(cache_expiry: Optional[int] = None): cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. """ if cache_expiry is None: - cache_expiry = 86400 + cache_expiry = 86400 api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -46,6 +48,7 @@ def load_pipelines(cache_expiry: Optional[int] = None): return parse_pipelines({"items": onboarded_pipelines}) + def parse_pipelines(pipelines_data): """ Convert pipeline data into an Enum and dictionary format for easy use. @@ -54,10 +57,10 @@ def parse_pipelines(pipelines_data): - pipelines_enum: Enum with pipeline IDs. - pipelines_details: Dictionary containing all pipeline parameters. 
""" - if not pipelines_data["items"]: + if not pipelines_data["items"]: logging.warning("No onboarded pipelines found.") return Enum("Pipeline", {}), {} - + pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) pipelines_details = { @@ -71,11 +74,12 @@ def parse_pipelines(pipelines_data): "status": pipeline.get("status", "onboarded"), "created_at": pipeline.get("created_at"), "architecture": pipeline.get("architecture", {}), - **pipeline, + **pipeline, } for pipeline in pipelines_data["items"] } return pipelines_enum, pipelines_details + Pipeline, PipelineDetails = load_pipelines() diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 6717234f..01981701 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -9,6 +9,7 @@ LOCK_FILE = f"{CACHE_FILE}.lock" DEFAULT_CACHE_EXPIRY = 86400 + def get_cache_expiry(): return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) @@ -22,6 +23,7 @@ def save_to_cache(cache_file, data, lock_file): except Exception as e: logging.error(f"Failed to save cache to {cache_file}: {e}") + def load_from_cache(cache_file, lock_file): if os.path.exists(cache_file): with FileLock(lock_file): diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index d39ca2b9..554b80d9 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -153,7 +153,9 @@ def upload_data( raise Exception("File Uploading Error: Failure on Uploading to S3.") -def s3_to_csv(s3_url: Text, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None}) -> Text: +def s3_to_csv( + s3_url: Text, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None} +) -> Text: """Convert s3 url to a csv file and download the file in `download_path` Args: diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index b9ef5465..17257b2a 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -2,6 +2,7 @@ from aixplain.modules.model.utility_model import UtilityModelInput, utility_tool from aixplain.enums import DataType + def test_run_utility_model(): utility_model = None try: @@ -36,22 +37,24 @@ def test_run_utility_model(): if utility_model: utility_model.delete() + def test_utility_model_with_decorator(): utility_model = None try: + @utility_tool( - name="add_numbers_test name", - description="Adds two numbers together.", - inputs=[ + name="add_numbers_test name", + description="Adds two numbers together.", + inputs=[ UtilityModelInput(name="num1", type=DataType.NUMBER, description="The first number."), - UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number.") + UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number."), ], ) def add_numbers(num1: int, num2: int) -> int: return num1 + num2 utility_model = ModelFactory.create_utility_model(code=add_numbers) - + assert utility_model.id is not None assert len(utility_model.inputs) == 2 assert utility_model.inputs[0].name == "num1" @@ -64,16 +67,18 @@ def add_numbers(num1: int, num2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_string_concatenation(): utility_model = None try: + @utility_tool( name="concatenate_strings", description="Concatenates two strings.", inputs=[ UtilityModelInput(name="str1", 
type=DataType.TEXT, description="The first string."), UtilityModelInput(name="str2", type=DataType.TEXT, description="The second string."), - ] + ], ) def concatenate_strings(str1: str, str2: str) -> str: """Concatenates two strings and returns the result.""" @@ -96,6 +101,7 @@ def concatenate_strings(str1: str, str2: str) -> str: if utility_model: utility_model.delete() + def test_utility_model_code_as_string(): utility_model = None try: @@ -108,10 +114,7 @@ def multiply_numbers(int1: int, int2: int) -> int: \"\"\"Multiply two numbers and returns the result.\"\"\" return int1 * int2 """ - utility_model = ModelFactory.create_utility_model( - name="Multiply Numbers Test", - code=code - ) + utility_model = ModelFactory.create_utility_model(name="Multiply Numbers Test", code=code) assert utility_model.id is not None assert len(utility_model.inputs) == 2 @@ -123,13 +126,15 @@ def multiply_numbers(int1: int, int2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_simple_function(): utility_model = None try: + def test_string(input: str): """test string""" return input - + utility_model = ModelFactory.create_utility_model( name="String Model Test", code=test_string, diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 305c6a52..8adfe5bc 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -25,7 +25,9 @@ def test_utility_model(): assert utility_model.name == "utility_model_test" assert utility_model.description == "utility_model_test" assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT)] + assert utility_model.inputs == [ + UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT) + ] assert utility_model.output_examples == "output_description" @@ -87,8 +89,14 @@ def test_utility_model_to_dict(): def test_update_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -127,8 +135,14 @@ def test_update_utility_model(): def test_save_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", 
return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -170,8 +184,14 @@ def test_save_utility_model(): def test_delete_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): mock.delete(urljoin(config.BACKEND_URL, "sdk/utilities/123"), status_code=200, json={"id": "123"}) utility_model = UtilityModel( id="123", @@ -241,8 +261,14 @@ def main(originCode): def test_validate_new_model(): """Test validation for a new model""" - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): # Test with valid inputs utility_model = UtilityModel( id="", # Empty ID for new model diff --git a/tests/unit/utility_tool_decorator_test.py b/tests/unit/utility_tool_decorator_test.py index f9c87f02..c63aa5f6 100644 --- a/tests/unit/utility_tool_decorator_test.py +++ b/tests/unit/utility_tool_decorator_test.py @@ -3,16 +3,15 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import 
utility_tool, UtilityModelInput + def test_utility_tool_basic_decoration(): """Test basic decoration with minimal parameters""" - @utility_tool( - name="test_function", - description="Test function description" - ) + + @utility_tool(name="test_function", description="Test function description") def test_func(input_text: str) -> str: return input_text - assert hasattr(test_func, '_is_utility_tool') + assert hasattr(test_func, "_is_utility_tool") assert test_func._is_utility_tool is True assert test_func._tool_name == "test_function" assert test_func._tool_description == "Test function description" @@ -20,19 +19,20 @@ def test_func(input_text: str) -> str: assert test_func._tool_output_examples == "" assert test_func._tool_status == AssetStatus.DRAFT + def test_utility_tool_with_all_parameters(): """Test decoration with all optional parameters""" inputs = [ UtilityModelInput(name="text_input", type=DataType.TEXT, description="A text input"), - UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input") + UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input"), ] - + @utility_tool( name="full_test_function", description="Full test function description", inputs=inputs, output_examples="Example output: Hello World", - status=AssetStatus.ONBOARDED + status=AssetStatus.ONBOARDED, ) def test_func(text_input: str, num_input: int) -> str: return f"{text_input} {num_input}" @@ -45,32 +45,28 @@ def test_func(text_input: str, num_input: int) -> str: assert test_func._tool_output_examples == "Example output: Hello World" assert test_func._tool_status == AssetStatus.ONBOARDED + def test_utility_tool_function_still_callable(): """Test that decorated function remains callable""" - @utility_tool( - name="callable_test", - description="Test function callable" - ) + + @utility_tool(name="callable_test", description="Test function callable") def test_func(x: int, y: int) -> int: return x + y assert test_func(2, 3) == 5 assert test_func._is_utility_tool is True + def test_utility_tool_invalid_inputs(): """Test validation of invalid inputs""" with pytest.raises(ValueError): - @utility_tool( - name="", # Empty name should raise error - description="Test description" - ) + + @utility_tool(name="", description="Test description") # Empty name should raise error def test_func(): pass with pytest.raises(ValueError): - @utility_tool( - name="test_name", - description="" # Empty description should raise error - ) + + @utility_tool(name="test_name", description="") # Empty description should raise error def test_func(): - pass \ No newline at end of file + pass From 7872aa40615dde5468a46858f59b18cc8e121843 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 7 Feb 2025 15:31:09 +0300 Subject: [PATCH 3/6] Agent Cache Class added --- aixplain/enums/__init__.py | 1 + aixplain/enums/aixplain_cache.py | 114 ++++++++++++++++++++ aixplain/modules/agent/__init__.py | 6 +- aixplain/modules/agent/cache_agents.py | 101 ----------------- aixplain/modules/model/__init__.py | 6 +- aixplain/modules/model/cache_models.py | 88 --------------- aixplain/modules/pipeline/asset.py | 4 +- aixplain/modules/pipeline/pipeline_cache.py | 85 --------------- 8 files changed, 124 insertions(+), 281 deletions(-) create mode 100644 aixplain/enums/aixplain_cache.py delete mode 100644 aixplain/modules/agent/cache_agents.py delete mode 100644 aixplain/modules/model/cache_models.py delete mode 100644 aixplain/modules/pipeline/pipeline_cache.py diff --git a/aixplain/enums/__init__.py 
b/aixplain/enums/__init__.py index ef497ddd..201dd9e7 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -14,3 +14,4 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus +from .aixplain_cache import AixplainCache \ No newline at end of file diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py new file mode 100644 index 00000000..339360a3 --- /dev/null +++ b/aixplain/enums/aixplain_cache.py @@ -0,0 +1,114 @@ +import os +import json +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Tuple, List +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER + + +class AixplainCache: + """ + A modular caching system to handle different asset types (Models, Pipelines, Agents). + This class reduces code repetition and allows easier maintenance. + """ + + def __init__(self, asset_type: str, cache_filename: str): + """ + Initialize the cache for a given asset type. + + Args: + asset_type (str): Type of the asset (e.g., "models", "pipelines", "agents"). + cache_filename (str): Filename for storing cached data. + """ + self.asset_type = asset_type + self.cache_file = f"{CACHE_FOLDER}/{cache_filename}.json" + self.lock_file = f"{self.cache_file}.lock" + os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists + + def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: + """ + Load assets from cache or fetch from backend if not cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. + + Returns: + Tuple[Enum, Dict]: (Enum of asset IDs, Dictionary with asset details) + """ + cached_data = load_from_cache(self.cache_file, self.lock_file) + if cached_data is not None: + return self.parse_assets(cached_data) + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + url = urljoin(backend_url, f"sdk/{self.asset_type}") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + try: + response = _request_with_retry("get", url, headers=headers) + response.raise_for_status() + assets_data = response.json() + except Exception as e: + logging.error(f"Failed to fetch {self.asset_type} from API: {e}") + return Enum(self.asset_type.capitalize(), {}), {} + + if "items" not in assets_data: + return Enum(self.asset_type.capitalize(), {}), {} + + onboarded_assets = [asset for asset in assets_data["items"] if asset.get("status", "").lower() == "onboarded"] + + save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) + + return self.parse_assets({"items": onboarded_assets}) + + def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: + """ + Convert asset data into an Enum and dictionary format for easy use. + + Args: + assets_data (Dict): JSON response with asset list. + + Returns: + - assets_enum: Enum with asset IDs. + - assets_details: Dictionary containing all asset parameters. 
+ """ + if not assets_data["items"]: # Handle case where no assets are onboarded + logging.warning(f"No onboarded {self.asset_type} found.") + return Enum(self.asset_type.capitalize(), {}), {} + + assets_enum = Enum( + self.asset_type.capitalize(), + {a["id"].upper().replace("-", "_"): a["id"] for a in assets_data["items"]}, + type=str, + ) + + assets_details = { + asset["id"]: { + "id": asset["id"], + "name": asset.get("name", ""), + "description": asset.get("description", ""), + "api_key": asset.get("api_key", config.TEAM_API_KEY), + "supplier": asset.get("supplier", "aiXplain"), + "version": asset.get("version", "1.0"), + "status": asset.get("status", "onboarded"), + "created_at": asset.get("created_at", ""), + **asset, # Include any extra fields + } + for asset in assets_data["items"] + } + + return assets_enum, assets_details + + +ModelCache = AixplainCache("models", "models") +Model, ModelDetails = ModelCache.load_assets() + +PipelineCache = AixplainCache("pipelines", "pipelines") +Pipeline, PipelineDetails = PipelineCache.load_assets() + +AgentCache = AixplainCache("agents", "agents") +Agent, AgentDetails = AgentCache.load_assets() diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 441c20fa..26052e06 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -37,9 +37,8 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus +from aixplain.enums import ResponseStatus, AixplainCache from aixplain.modules.agent.utils import process_variables -from aixplain.modules.agent.cache_agents import load_agents from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -78,7 +77,8 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: - AgentCache, AgentDetails = load_agents(cache_expiry=86400) + AgentCache = AixplainCache("agents", "agents") + AgentEnum, AgentDetails = AgentCache.load_assets() if id in AgentDetails: cached_agent = AgentDetails[id] diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py deleted file mode 100644 index 4d38344f..00000000 --- a/aixplain/modules/agent/cache_agents.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import json -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, List, Tuple, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier -from aixplain.modules.agent.tool import Tool - -AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" -LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" - - -def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: - """ - Load AI agents from cache or fetch from backend if not cached. - Only agents with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. 
- - Returns: - Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) - """ - if cache_expiry is None: - cache_expiry = 86400 - - os.makedirs(CACHE_FOLDER, exist_ok=True) - - cached_data = load_from_cache(AGENT_CACHE_FILE, LOCK_FILE) - - if cached_data is not None: - return parse_agents(cached_data) - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - url = urljoin(backend_url, "sdk/agents") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - try: - response = _request_with_retry("get", url, headers=headers) - response.raise_for_status() - agents_data = response.json() - except Exception as e: - logging.error(f"Failed to fetch agents from API: {e}") - return Enum("Agent", {}), {} - - onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] - - save_to_cache(AGENT_CACHE_FILE, {"items": onboarded_agents}, LOCK_FILE) - - return parse_agents({"items": onboarded_agents}) - - -def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: - """ - Convert agent data into an Enum and dictionary format for easy use. - - Args: - agents_data (Dict): JSON response with agents list. - - Returns: - - agents_enum: Enum with agent IDs. - - agents_details: Dictionary containing all agent parameters. - """ - if not agents_data["items"]: - logging.warning("No onboarded agents found.") - return Enum("Agent", {}), {} - - agents_enum = Enum( - "Agent", - {a["id"].upper().replace("-", "_"): a["id"] for a in agents_data["items"]}, - type=str, - ) - - agents_details = { - agent["id"]: { - "id": agent["id"], - "name": agent.get("name", ""), - "description": agent.get("description", ""), - "role": agent.get("role", ""), - "tools": [Tool(t) if isinstance(t, dict) else t for t in agent.get("tools", [])], - "llm_id": agent.get("llm_id", "6646261c6eb563165658bbb1"), - "supplier": agent.get("supplier", "aiXplain"), - "version": agent.get("version", "1.0"), - "status": agent.get("status", "onboarded"), - "created_at": agent.get("created_at", ""), - "tasks": agent.get("tasks", []), - **agent, - } - for agent in agents_data["items"] - } - - return agents_enum, agents_details - - -Agent, AgentDetails = load_agents() diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index f3962e01..3aee0a31 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -23,7 +23,7 @@ import time import logging import traceback -from aixplain.enums import Supplier, Function +from aixplain.enums import Supplier, Function, AixplainCache from aixplain.modules.asset import Asset from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config @@ -34,7 +34,6 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters -from aixplain.modules.model.cache_models import load_models class Model(Asset): @@ -92,7 +91,8 @@ def __init__( model_params (Dict, optional): parameters for the function. 
**additional_info: Any additional Model info to be saved """ - ModelCache, ModelDetails = load_models(cache_expiry=86400) + ModelCache = AixplainCache("models", "models") + ModelEnum, ModelDetails = ModelCache.load_assets() if id in ModelDetails: diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py deleted file mode 100644 index a336c488..00000000 --- a/aixplain/modules/model/cache_models.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import json -import time -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier, Function - -CACHE_FILE = f"{CACHE_FOLDER}/models.json" -LOCK_FILE = f"{CACHE_FILE}.lock" - - -def load_models(cache_expiry: Optional[int] = None): - """ - Load models from cache or fetch from backend if not cached. - Only models with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. - """ - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - - cached_data = load_from_cache(CACHE_FILE, LOCK_FILE) - - if cached_data is not None: - - return parse_models(cached_data) - - url = urljoin(backend_url, "sdk/models") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - response = _request_with_retry("get", url, headers=headers) - if not 200 <= response.status_code < 300: - raise Exception(f"Models could not be loaded, API key '{api_key}' might be invalid.") - - models_data = response.json() - - onboarded_models = [model for model in models_data["items"] if model["status"].lower() == "onboarded"] - save_to_cache(CACHE_FILE, {"items": onboarded_models}, LOCK_FILE) - - return parse_models({"items": onboarded_models}) - - -def parse_models(models_data): - """ - Convert model data into an Enum and dictionary format for easy use. - - Returns: - - models_enum: Enum with model IDs. - - models_details: Dictionary containing all model parameters. 
- """ - - if not models_data["items"]: - logging.warning("No onboarded models found.") - return Enum("Model", {}), {} - models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) - - models_details = { - model["id"]: { - "id": model["id"], - "name": model.get("name", ""), - "description": model.get("description", ""), - "api_key": model.get("api_key", config.TEAM_API_KEY), - "supplier": model.get("supplier", "aiXplain"), - "version": model.get("version"), - "function": model.get("function"), - "is_subscribed": model.get("is_subscribed", False), - "cost": model.get("cost"), - "created_at": model.get("created_at"), - "input_params": model.get("input_params"), - "output_params": model.get("output_params"), - "model_params": model.get("model_params"), - **model, - } - for model in models_data["items"] - } - - return models_enum, models_details - - -Model, ModelDetails = load_models() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 216274fc..ee2c421e 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -31,7 +31,7 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin -from aixplain.modules.pipeline.pipeline_cache import PipelineDetails +from aixplain.enums import AixplainCache class Pipeline(Asset): @@ -61,6 +61,8 @@ def __init__( status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: + PipelineCache = AixplainCache("pipelines", "pipelines") + PipelineEnum, PipelineDetails = PipelineCache.load_assets() """Create a Pipeline with the necessary information Args: diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py deleted file mode 100644 index 2d90e069..00000000 --- a/aixplain/modules/pipeline/pipeline_cache.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import json -import time -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier - -PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" -LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" - - -def load_pipelines(cache_expiry: Optional[int] = None): - """ - Load pipelines from cache or fetch from backend if not cached. - Only pipelines with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. 
- """ - if cache_expiry is None: - cache_expiry = 86400 - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - - cached_data = load_from_cache(PIPELINE_CACHE_FILE, LOCK_FILE) - if cached_data is not None: - return parse_pipelines(cached_data) - - url = urljoin(backend_url, "sdk/pipelines") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - response = _request_with_retry("get", url, headers=headers) - if not 200 <= response.status_code < 300: - raise Exception(f"Pipelines could not be loaded, API key '{api_key}' might be invalid.") - - pipelines_data = response.json() - - onboarded_pipelines = [pipeline for pipeline in pipelines_data["items"] if pipeline["status"].lower() == "onboarded"] - - save_to_cache(PIPELINE_CACHE_FILE, {"items": onboarded_pipelines}, LOCK_FILE) - - return parse_pipelines({"items": onboarded_pipelines}) - - -def parse_pipelines(pipelines_data): - """ - Convert pipeline data into an Enum and dictionary format for easy use. - - Returns: - - pipelines_enum: Enum with pipeline IDs. - - pipelines_details: Dictionary containing all pipeline parameters. - """ - if not pipelines_data["items"]: - logging.warning("No onboarded pipelines found.") - return Enum("Pipeline", {}), {} - - pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) - - pipelines_details = { - pipeline["id"]: { - "id": pipeline["id"], - "name": pipeline.get("name", ""), - "description": pipeline.get("description", ""), - "api_key": pipeline.get("api_key", config.TEAM_API_KEY), - "supplier": pipeline.get("supplier", "aiXplain"), - "version": pipeline.get("version", "1.0"), - "status": pipeline.get("status", "onboarded"), - "created_at": pipeline.get("created_at"), - "architecture": pipeline.get("architecture", {}), - **pipeline, - } - for pipeline in pipelines_data["items"] - } - - return pipelines_enum, pipelines_details - - -Pipeline, PipelineDetails = load_pipelines() From 92d8e283ca590f2625877a018f863bf351ca5a7f Mon Sep 17 00:00:00 2001 From: ahmetgunduz Date: Mon, 10 Feb 2025 10:49:28 +0000 Subject: [PATCH 4/6] removed unused imports --- aixplain/enums/aixplain_cache.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 339360a3..23129e24 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -1,10 +1,8 @@ import os -import json import logging -from datetime import datetime from enum import Enum from urllib.parse import urljoin -from typing import Dict, Optional, Tuple, List +from typing import Dict, Optional, Tuple from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER @@ -54,7 +52,7 @@ def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: assets_data = response.json() except Exception as e: logging.error(f"Failed to fetch {self.asset_type} from API: {e}") - return Enum(self.asset_type.capitalize(), {}), {} + return Enum(self.asset_type.capitalize(), {}), {} if "items" not in assets_data: return Enum(self.asset_type.capitalize(), {}), {} @@ -104,11 +102,16 @@ def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: return assets_enum, assets_details -ModelCache = AixplainCache("models", "models") -Model, ModelDetails = ModelCache.load_assets() +if __name__ == "__main__": + ModelCache = 
AixplainCache("models", "models") + Model, ModelDetails = ModelCache.load_assets() + + PipelineCache = AixplainCache("pipelines", "pipelines") + Pipeline, PipelineDetails = PipelineCache.load_assets() -PipelineCache = AixplainCache("pipelines", "pipelines") -Pipeline, PipelineDetails = PipelineCache.load_assets() + AgentCache = AixplainCache("agents", "agents") + Agent, AgentDetails = AgentCache.load_assets() -AgentCache = AixplainCache("agents", "agents") -Agent, AgentDetails = AgentCache.load_assets() + print(ModelDetails) + print(PipelineDetails) + print(AgentDetails) From 31d7b7954fdfa56a598041c5530d63185116bdc5 Mon Sep 17 00:00:00 2001 From: xainaz Date: Wed, 19 Mar 2025 14:23:29 +0300 Subject: [PATCH 5/6] made requested changes and added functional tests --- aixplain/enums/aixplain_cache.py | 22 +++------------- aixplain/modules/agent/__init__.py | 32 +++++++++++------------ aixplain/modules/model/__init__.py | 23 +++++++++------- tests/functional/model/run_model_test.py | 29 ++++++++++++++++++-- tests/functional/pipelines/create_test.py | 27 ++++++++++++++++++- 5 files changed, 86 insertions(+), 47 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 23129e24..26543693 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -14,7 +14,7 @@ class AixplainCache: This class reduces code repetition and allows easier maintenance. """ - def __init__(self, asset_type: str, cache_filename: str): + def __init__(self, asset_type: str, cache_filename: Optional[str] = None): """ Initialize the cache for a given asset type. @@ -23,7 +23,8 @@ def __init__(self, asset_type: str, cache_filename: str): cache_filename (str): Filename for storing cached data. """ self.asset_type = asset_type - self.cache_file = f"{CACHE_FOLDER}/{cache_filename}.json" + filename = cache_filename if cache_filename else asset_type + self.cache_file = f"{CACHE_FOLDER}/{filename}.json" self.lock_file = f"{self.cache_file}.lock" os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists @@ -99,19 +100,4 @@ def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: for asset in assets_data["items"] } - return assets_enum, assets_details - - -if __name__ == "__main__": - ModelCache = AixplainCache("models", "models") - Model, ModelDetails = ModelCache.load_assets() - - PipelineCache = AixplainCache("pipelines", "pipelines") - Pipeline, PipelineDetails = PipelineCache.load_assets() - - AgentCache = AixplainCache("agents", "agents") - Agent, AgentDetails = AgentCache.load_assets() - - print(ModelDetails) - print(PipelineDetails) - print(AgentDetails) + return assets_enum, assets_details \ No newline at end of file diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 372c78c3..4bf54565 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -61,6 +61,8 @@ class Agent(Model): api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. 
""" + AgentCache = AixplainCache("agents", "agents") + AgentEnum, AgentDetails = AgentCache.load_assets() is_valid: bool @@ -80,22 +82,20 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: - AgentCache = AixplainCache("agents", "agents") - AgentEnum, AgentDetails = AgentCache.load_assets() - - if id in AgentDetails: - cached_agent = AgentDetails[id] - name = cached_agent["name"] - description = cached_agent["description"] - tools = cached_agent["tools"] - llm_id = cached_agent["llm_id"] - api_key = cached_agent["api_key"] - supplier = cached_agent["supplier"] - version = cached_agent["version"] - cost = cached_agent["cost"] - status = cached_agent["status"] - tasks = cached_agent["tasks"] - additional_info = cached_agent["additional_info"] + + if id in self.__class__.AgentDetails: + cached_agent = self.__class__.AgentDetails[id] + name = cached_agent.get("name", name) + description = cached_agent.get("description", description) + tools = cached_agent.get("tools", tools) + llm_id = cached_agent.get("llm_id", llm_id) + api_key = cached_agent.get("api_key", api_key) + supplier = cached_agent.get("supplier", supplier) + version = cached_agent.get("version", version) + cost = cached_agent.get("cost", cost) + status = cached_agent.get("status", status) + tasks = cached_agent.get("tasks", tasks) + additional_info = cached_agent.get("additional_info", additional_info) """Create an Agent with the necessary information. diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 3aee0a31..e7e3def9 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -95,21 +95,24 @@ def __init__( ModelEnum, ModelDetails = ModelCache.load_assets() if id in ModelDetails: - cached_model = ModelDetails[id] - input_params = cached_model["input"] - api_key = cached_model["spec"].get("api_key", api_key) - additional_info = cached_model["spec"].get("additional_info", {}) - function = cached_model["spec"].get("function", function) - is_subscribed = cached_model["spec"].get("is_subscribed", is_subscribed) - created_at = cached_model["spec"].get("created_at", created_at) - model_params = cached_model["spec"].get("model_params", model_params) - output_params = cached_model["output"] - description = cached_model["spec"].get("description", description) + + input_params = cached_model.get("params", input_params) + function = cached_model.get("function", {}).get("name", function) + name = cached_model.get("name", name) + supplier = cached_model.get("supplier", supplier) + + created_at_str = cached_model.get("createdAt") + if created_at_str: + created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + + cost = cached_model.get("pricing", cost) + super().__init__(id, name, description, supplier, version, cost=cost) self.api_key = api_key self.additional_info = additional_info + self.name = name self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.function = function diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index d3d0082f..6c601917 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -1,12 +1,14 @@ __author__ = "thiagocastroferreira" - +import os +import json from aixplain.enums import Function from aixplain.factories import ModelFactory from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path - +from aixplain.utils.cache_utils import 
CACHE_FOLDER +from aixplain.modules.model import Model def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: @@ -94,3 +96,26 @@ def test_llm_run_with_file(): # Verify response assert response["status"] == "SUCCESS" assert "🤖" in response["data"], "Robot emoji should be present in the response" + + +def test_aixplain_model_cache_creation(): + """Ensure AixplainCache is triggered and cache is created.""" + + cache_file = os.path.join(CACHE_FOLDER, "models.json") + + # Clean up cache before the test + if os.path.exists(cache_file): + os.remove(cache_file) + + # Instantiate the Model (replace this with a real model ID from your env) + model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) + _ = Model(id=model_id) + + # Assert the cache file was created + assert os.path.exists(cache_file), "Expected cache file was not created." + + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." + assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index ae33f454..8bcb3b7d 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -15,7 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - +import os +from aixplain.utils.cache_utils import CACHE_FOLDER +from aixplain.modules.pipeline import Pipeline import json import pytest from aixplain.factories import PipelineFactory @@ -75,3 +77,26 @@ def test_create_pipeline_wrong_path(PipelineFactory): with pytest.raises(Exception): PipelineFactory.create(name=pipeline_name, pipeline="/") + + + +@pytest.mark.parametrize("PipelineFactory", [PipelineFactory]) +def test_pipeline_cache_creation(PipelineFactory): + cache_file = os.path.join(CACHE_FOLDER, "pipelines.json") + if os.path.exists(cache_file): + os.remove(cache_file) + + pipeline_json = "tests/functional/pipelines/data/pipeline.json" + pipeline_name = str(uuid4()) + pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) + + assert os.path.exists(cache_file), "Pipeline cache file was not created!" + + with open(cache_file, "r") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache format invalid, missing 'data'." 
+ + pipeline.delete() + if os.path.exists(cache_file): + os.remove(cache_file) \ No newline at end of file From be58677b07c6988bc72d848d6bf314205d67ac57 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 24 Mar 2025 22:40:40 +0300 Subject: [PATCH 6/6] changes --- aixplain/enums/aixplain_cache.py | 168 ++++++++++++++++++----------- aixplain/enums/function.py | 6 +- aixplain/enums/language.py | 6 +- aixplain/enums/license.py | 6 +- aixplain/modules/model/__init__.py | 72 +++++++++---- 5 files changed, 170 insertions(+), 88 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 26543693..f8cb2f78 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -1,103 +1,151 @@ import os import logging +import json +import time from enum import Enum from urllib.parse import urljoin from typing import Dict, Optional, Tuple +from dataclasses import dataclass +from filelock import FileLock + from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER - +from aixplain.enums.privacy import Privacy +# Constants +CACHE_FOLDER = ".cache" +DEFAULT_CACHE_EXPIRY = 86400 + +@dataclass +class Asset: + id: str + name: str = "" + description: str = "" + api_key: str = config.TEAM_API_KEY + supplier: str = "aiXplain" + version: str = "1.0" + status: str = "onboarded" + created_at: str = "" + +class AssetType(Enum): + MODELS = "models" + PIPELINES = "pipelines" + AGENTS = "agents" + +def get_cache_expiry(): + return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + +def _serialize(obj): + if isinstance(obj, (Privacy)): + return str(obj) # or obj.to_dict() if you have it + if isinstance(obj, Enum): + return obj.value + return obj.__dict__ if hasattr(obj, "__dict__") else str(obj) class AixplainCache: """ A modular caching system to handle different asset types (Models, Pipelines, Agents). - This class reduces code repetition and allows easier maintenance. """ - def __init__(self, asset_type: str, cache_filename: Optional[str] = None): - """ - Initialize the cache for a given asset type. - - Args: - asset_type (str): Type of the asset (e.g., "models", "pipelines", "agents"). - cache_filename (str): Filename for storing cached data. 
- """ - self.asset_type = asset_type - filename = cache_filename if cache_filename else asset_type - self.cache_file = f"{CACHE_FOLDER}/{filename}.json" - self.lock_file = f"{self.cache_file}.lock" - os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists - - def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: + @staticmethod + def save_to_cache(cache_file, data, lock_file): + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + with FileLock(lock_file): + with open(cache_file, "w") as f: + json.dump({"timestamp": time.time(), "data": data}, f,default=_serialize) + + @staticmethod + def load_from_cache(cache_file, lock_file): + if os.path.exists(cache_file): + with FileLock(lock_file): + with open(cache_file, "r") as f: + cache_data = json.load(f) + if time.time() - cache_data["timestamp"] < get_cache_expiry(): + return cache_data["data"] + else: + try: + os.remove(cache_file) + if os.path.exists(lock_file): + os.remove(lock_file) + except Exception as e: + logging.warning(f"Failed to remove expired cache or lock file: {e}") + return None + + @staticmethod + def fetch_assets_from_backend(asset_type_str: str) -> Optional[Dict]: """ - Load assets from cache or fetch from backend if not cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. - - Returns: - Tuple[Enum, Dict]: (Enum of asset IDs, Dictionary with asset details) + Fetch assets data from the backend API. """ - cached_data = load_from_cache(self.cache_file, self.lock_file) - if cached_data is not None: - return self.parse_assets(cached_data) - api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL - url = urljoin(backend_url, f"sdk/{self.asset_type}") + url = urljoin(backend_url, f"sdk/{asset_type_str}") headers = {"x-api-key": api_key, "Content-Type": "application/json"} try: response = _request_with_retry("get", url, headers=headers) response.raise_for_status() - assets_data = response.json() + return response.json() except Exception as e: - logging.error(f"Failed to fetch {self.asset_type} from API: {e}") - return Enum(self.asset_type.capitalize(), {}), {} + logging.error(f"Failed to fetch {asset_type_str} from API: {e}") + return None - if "items" not in assets_data: - return Enum(self.asset_type.capitalize(), {}), {} + def __init__(self, asset_type: AssetType | str, cache_filename: Optional[str] = None): + if isinstance(asset_type, str): + asset_type = AssetType(asset_type.lower()) + self.asset_type = asset_type - onboarded_assets = [asset for asset in assets_data["items"] if asset.get("status", "").lower() == "onboarded"] + filename = cache_filename if cache_filename else self.asset_type.value + self.cache_file = f"{CACHE_FOLDER}/{filename}.json" + self.lock_file = f"{self.cache_file}.lock" + os.makedirs(CACHE_FOLDER, exist_ok=True) - save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) + + def load_assets(self) -> Tuple[Enum, Dict]: + """ + Load assets from cache or fetch from backend if not cached. 
+ """ + cached_data = self.load_from_cache(self.cache_file, self.lock_file) + if cached_data: + return self.parse_assets(cached_data) + + assets_data = self.fetch_assets_from_backend(self.asset_type.value) + if not assets_data or "items" not in assets_data: + return Enum(self.asset_type.name, {}), {} + + onboarded_assets = [ + asset for asset in assets_data["items"] + if asset.get("status", "").lower() == "onboarded" + ] + + self.save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) return self.parse_assets({"items": onboarded_assets}) def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: """ Convert asset data into an Enum and dictionary format for easy use. - - Args: - assets_data (Dict): JSON response with asset list. - - Returns: - - assets_enum: Enum with asset IDs. - - assets_details: Dictionary containing all asset parameters. """ - if not assets_data["items"]: # Handle case where no assets are onboarded - logging.warning(f"No onboarded {self.asset_type} found.") - return Enum(self.asset_type.capitalize(), {}), {} + if not assets_data["items"]: + logging.warning(f"No onboarded {self.asset_type.value} found.") + return Enum(self.asset_type.name, {}), {} assets_enum = Enum( - self.asset_type.capitalize(), + self.asset_type.name, {a["id"].upper().replace("-", "_"): a["id"] for a in assets_data["items"]}, type=str, ) assets_details = { - asset["id"]: { - "id": asset["id"], - "name": asset.get("name", ""), - "description": asset.get("description", ""), - "api_key": asset.get("api_key", config.TEAM_API_KEY), - "supplier": asset.get("supplier", "aiXplain"), - "version": asset.get("version", "1.0"), - "status": asset.get("status", "onboarded"), - "created_at": asset.get("created_at", ""), - **asset, # Include any extra fields - } + asset["id"]: Asset( + id=asset["id"], + name=asset.get("name", ""), + description=asset.get("description", ""), + api_key=asset.get("api_key", config.TEAM_API_KEY), + supplier=asset.get("supplier", "aiXplain"), + version=asset.get("version", "1.0"), + status=asset.get("status", "onboarded"), + created_at=asset.get("created_at", "") + ) for asset in assets_data["items"] } - return assets_enum, assets_details \ No newline at end of file + return assets_enum, assets_details diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a77f3cfc..62408c58 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -25,7 +25,7 @@ from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter import os @@ -40,7 +40,7 @@ def load_functions(): os.makedirs(CACHE_FOLDER, exist_ok=True) - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: url = urljoin(backend_url, "sdk/functions") @@ -51,7 +51,7 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. 
For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index c129822f..e0d035bd 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -25,14 +25,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -46,7 +46,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index f9758b84..ef9285d9 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -26,14 +26,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) try: if resp is None: @@ -49,7 +49,7 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index e7e3def9..c07222e6 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,6 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.enums.aixplain_cache import AssetType class Model(Asset): @@ -91,36 +92,69 @@ def __init__( model_params (Dict, optional): parameters for the function. 
**additional_info: Any additional Model info to be saved """ - ModelCache = AixplainCache("models", "models") - ModelEnum, ModelDetails = ModelCache.load_assets() - - if id in ModelDetails: - cached_model = ModelDetails[id] - - input_params = cached_model.get("params", input_params) - function = cached_model.get("function", {}).get("name", function) - name = cached_model.get("name", name) - supplier = cached_model.get("supplier", supplier) - - created_at_str = cached_model.get("createdAt") + model_details = self._get_model_details_from_cache(id) + + if model_details: + name = model_details.get("name", name) + description = model_details.get("description", description) + supplier = model_details.get("supplier", supplier) + version = model_details.get("version", version) + function = model_details.get("function", {}).get("name", function) + cost = model_details.get("pricing", cost) + input_params = model_details.get("params", input_params) + output_params = model_details.get("output_params", output_params) + model_params = model_details.get("model_params", model_params) + created_at_str = model_details.get("createdAt") if created_at_str: created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) - - cost = cached_model.get("pricing", cost) - + is_subscribed = model_details.get("is_subscribed", is_subscribed) super().__init__(id, name, description, supplier, version, cost=cost) + self.api_key = api_key - self.additional_info = additional_info - self.name = name - self.url = config.MODELS_RUN_URL - self.backend_url = config.BACKEND_URL self.function = function self.is_subscribed = is_subscribed self.created_at = created_at self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + self.additional_info = additional_info + + if not model_details: + self._add_model_to_cache() + + @staticmethod + def _get_model_details_from_cache(model_id: str) -> Optional[Dict]: + """ + Private helper to load model details from the cache. + """ + try: + model_cache = AixplainCache(AssetType.MODELS) + _, models_data = model_cache.load_assets() + model_asset = models_data.get(model_id) + return model_asset.__dict__ if model_asset else None + except Exception as e: + logging.error(f"Error loading model from cache: {e}") + traceback.print_exc() + return None + + def _add_model_to_cache(self): + try: + + model_cache = AixplainCache(AssetType.MODELS) + _, models_data = model_cache.load_assets() + + models_data[self.id] = self + + serializable_data = {mid: vars(m) for mid, m in models_data.items()} + + model_cache.save_to_cache(model_cache.cache_file, {"items": list(serializable_data.values())}, model_cache.lock_file) + + logging.info(f"Model {self.id} added to cache.") + except Exception as e: + logging.error(f"Failed to add model {self.id} to cache: {e}") def to_dict(self) -> Dict: """Get the model info as a Dictionary