From 292f441ea063dd9fd370e82020bec0dede93e5b4 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 6 Feb 2025 00:38:05 +0300 Subject: [PATCH 01/13] caching for agents, pipelines and models --- aixplain/enums/function.py | 11 ++- aixplain/enums/language.py | 5 +- aixplain/enums/license.py | 6 +- aixplain/modules/agent/__init__.py | 17 ++++ aixplain/modules/agent/cache_agents.py | 101 ++++++++++++++++++++ aixplain/modules/model/__init__.py | 17 +++- aixplain/modules/model/cache_models.py | 86 +++++++++++++++++ aixplain/modules/pipeline/asset.py | 10 ++ aixplain/modules/pipeline/pipeline_cache.py | 81 ++++++++++++++++ aixplain/utils/cache_utils.py | 35 ++++--- 10 files changed, 347 insertions(+), 22 deletions(-) create mode 100644 aixplain/modules/agent/cache_agents.py create mode 100644 aixplain/modules/model/cache_models.py create mode 100644 aixplain/modules/pipeline/pipeline_cache.py diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a51f5301..f8190c9d 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -28,15 +28,18 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter +import os CACHE_FILE = f"{CACHE_FOLDER}/functions.json" - +LOCK_FILE = f"{CACHE_FILE}.lock" def load_functions(): api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL - resp = load_from_cache(CACHE_FILE) + os.makedirs(CACHE_FOLDER, exist_ok=True) + + resp = load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: url = urljoin(backend_url, "sdk/functions") @@ -47,7 +50,7 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): @@ -63,6 +66,8 @@ def get_input_output_params(self) -> Tuple[Dict, Dict]: Tuple[Dict, Dict]: A tuple containing (input_params, output_params) """ function_io = FunctionInputOutput.get(self.value, None) + if function_io is None: + return {}, {} input_params = {param["code"]: param for param in function_io["spec"]["params"]} output_params = {param["code"]: param for param in function_io["spec"]["output"]} return input_params, output_params diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index db66b2a1..52938793 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -28,10 +28,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE) + resp = load_from_cache(CACHE_FILE,LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -45,7 +46,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index a860a539..22ad9b0d 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -29,10 +29,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = load_from_cache(CACHE_FILE) + resp = load_from_cache(CACHE_FILE, LOCK_FILE) try: if resp is None: @@ -48,7 +49,8 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + save_to_cache(CACHE_FILE, resp, LOCK_FILE) + licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) except Exception: diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index a7707d05..9350a638 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -39,6 +39,7 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables +from aixplain.modules.agent.cache_agents import load_agents from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -77,6 +78,22 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: + AgentCache, AgentDetails = load_agents(cache_expiry=86400) + + if id in AgentDetails: + cached_agent= AgentDetails[id] + name=cached_agent["name"] + description=cached_agent["description"] + tools=cached_agent["tools"] + llm_id=cached_agent["llm_id"] + api_key=cached_agent["api_key"] + supplier=cached_agent["supplier"] + version=cached_agent["version"] + cost=cached_agent["cost"] + status=cached_agent["status"] + tasks=cached_agent["tasks"] + additional_info=cached_agent["additional_info"] + """Create an Agent with the necessary information. Args: diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py new file mode 100644 index 00000000..a04bad00 --- /dev/null +++ b/aixplain/modules/agent/cache_agents.py @@ -0,0 +1,101 @@ +import os +import json +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, List, Tuple, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier +from aixplain.modules.agent.tool import Tool + +AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" +LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" + +def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: + """ + Load AI agents from cache or fetch from backend if not cached. + Only agents with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. + + Returns: + Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) + """ + if cache_expiry is None: + cache_expiry = 86400 + + os.makedirs(CACHE_FOLDER, exist_ok=True) + + cached_data = load_from_cache(AGENT_CACHE_FILE, LOCK_FILE) + + if cached_data is not None: + return parse_agents(cached_data) + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + url = urljoin(backend_url, "sdk/agents") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + try: + response = _request_with_retry("get", url, headers=headers) + response.raise_for_status() + agents_data = response.json() + except Exception as e: + logging.error(f"Failed to fetch agents from API: {e}") + return Enum("Agent", {}), {} + + + onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] + + save_to_cache(AGENT_CACHE_FILE, {"items": onboarded_agents}, LOCK_FILE) + + return parse_agents({"items": onboarded_agents}) + + +def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: + """ + Convert agent data into an Enum and dictionary format for easy use. + + Args: + agents_data (Dict): JSON response with agents list. + + Returns: + - agents_enum: Enum with agent IDs. + - agents_details: Dictionary containing all agent parameters. + """ + if not agents_data["items"]: + logging.warning("No onboarded agents found.") + return Enum("Agent", {}), {} + + agents_enum = Enum( + "Agent", + {a["id"].upper().replace("-", "_"): a["id"] for a in agents_data["items"]}, + type=str, + ) + + agents_details = { + agent["id"]: { + "id": agent["id"], + "name": agent.get("name", ""), + "description": agent.get("description", ""), + "role": agent.get("role", ""), + "tools": [Tool(t) if isinstance(t, dict) else t for t in agent.get("tools", [])], + "llm_id": agent.get("llm_id", "6646261c6eb563165658bbb1"), + "supplier": agent.get("supplier", "aiXplain"), + "version": agent.get("version", "1.0"), + "status": agent.get("status", "onboarded"), + "created_at": agent.get("created_at", ""), + "tasks": agent.get("tasks", []), + **agent, + } + for agent in agents_data["items"] + } + + return agents_enum, agents_details + + +Agent, AgentDetails = load_agents() diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 5788daab..69e02cc6 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,7 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters - +from aixplain.modules.model.cache_models import load_models class Model(Asset): """This is ready-to-use AI model. This model can be run in both synchronous and asynchronous manner. @@ -91,6 +91,21 @@ def __init__( model_params (Dict, optional): parameters for the function. **additional_info: Any additional Model info to be saved """ + ModelCache, ModelDetails = load_models(cache_expiry=86400) + + if id in ModelDetails: + + cached_model = ModelDetails[id] + input_params = cached_model["input"] + api_key = cached_model["spec"].get("api_key", api_key) + additional_info = cached_model["spec"].get("additional_info", {}) + function = cached_model["spec"].get("function", function) + is_subscribed = cached_model["spec"].get("is_subscribed", is_subscribed) + created_at = cached_model["spec"].get("created_at", created_at) + model_params = cached_model["spec"].get("model_params", model_params) + output_params = cached_model["output"] + description = cached_model["spec"].get("description", description) + super().__init__(id, name, description, supplier, version, cost=cost) self.api_key = api_key self.additional_info = additional_info diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py new file mode 100644 index 00000000..d5628587 --- /dev/null +++ b/aixplain/modules/model/cache_models.py @@ -0,0 +1,86 @@ +import os +import json +import time +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier, Function + +CACHE_FILE = f"{CACHE_FOLDER}/models.json" +LOCK_FILE = f"{CACHE_FILE}.lock" + +def load_models(cache_expiry: Optional[int] = None): + """ + Load models from cache or fetch from backend if not cached. + Only models with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. + """ + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + + cached_data = load_from_cache(CACHE_FILE, LOCK_FILE) + + if cached_data is not None: + + return parse_models(cached_data) + + url = urljoin(backend_url, "sdk/models") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + response = _request_with_retry("get", url, headers=headers) + if not 200 <= response.status_code < 300: + raise Exception(f"Models could not be loaded, API key '{api_key}' might be invalid.") + + models_data = response.json() + + onboarded_models = [model for model in models_data["items"] if model["status"].lower() == "onboarded"] + save_to_cache(CACHE_FILE, {"items": onboarded_models}, LOCK_FILE) + + return parse_models({"items": onboarded_models}) + +def parse_models(models_data): + """ + Convert model data into an Enum and dictionary format for easy use. + + Returns: + - models_enum: Enum with model IDs. + - models_details: Dictionary containing all model parameters. + """ + + if not models_data["items"]: + logging.warning("No onboarded models found.") + return Enum("Model", {}), {} + models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) + + models_details = { + model["id"]: { + "id": model["id"], + "name": model.get("name", ""), + "description": model.get("description", ""), + "api_key": model.get("api_key", config.TEAM_API_KEY), + "supplier": model.get("supplier", "aiXplain"), + "version": model.get("version"), + "function": model.get("function"), + "is_subscribed": model.get("is_subscribed", False), + "cost": model.get("cost"), + "created_at": model.get("created_at"), + "input_params": model.get("input_params"), + "output_params": model.get("output_params"), + "model_params": model.get("model_params"), + **model, + } + for model in models_data["items"] + } + + return models_enum, models_details + + +Model, ModelDetails = load_models() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 88364873..0cf5ba63 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -31,6 +31,7 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin +from aixplain.modules.pipeline.pipeline_cache import PipelineDetails class Pipeline(Asset): @@ -72,6 +73,15 @@ def __init__( status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ + if id in PipelineDetails: + cached_pipeline = PipelineDetails[id] + name = cached_pipeline["name"] + api_key = cached_pipeline["api_key"] + supplier = cached_pipeline["supplier"] + version = cached_pipeline["version"] + status = cached_pipeline["status"] + additional_info = cached_pipeline["architecture"] + if not name: raise ValueError("Pipeline name is required") diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py new file mode 100644 index 00000000..4c2f817e --- /dev/null +++ b/aixplain/modules/pipeline/pipeline_cache.py @@ -0,0 +1,81 @@ +import os +import json +import time +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Union, Text +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums import Supplier + +PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" +LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" +def load_pipelines(cache_expiry: Optional[int] = None): + """ + Load pipelines from cache or fetch from backend if not cached. + Only pipelines with status "onboarded" should be cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. + """ + if cache_expiry is None: + cache_expiry = 86400 + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + + cached_data = load_from_cache(PIPELINE_CACHE_FILE, LOCK_FILE) + if cached_data is not None: + return parse_pipelines(cached_data) + + url = urljoin(backend_url, "sdk/pipelines") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + response = _request_with_retry("get", url, headers=headers) + if not 200 <= response.status_code < 300: + raise Exception(f"Pipelines could not be loaded, API key '{api_key}' might be invalid.") + + pipelines_data = response.json() + + onboarded_pipelines = [pipeline for pipeline in pipelines_data["items"] if pipeline["status"].lower() == "onboarded"] + + save_to_cache(PIPELINE_CACHE_FILE, {"items": onboarded_pipelines}, LOCK_FILE) + + return parse_pipelines({"items": onboarded_pipelines}) + +def parse_pipelines(pipelines_data): + """ + Convert pipeline data into an Enum and dictionary format for easy use. + + Returns: + - pipelines_enum: Enum with pipeline IDs. + - pipelines_details: Dictionary containing all pipeline parameters. + """ + if not pipelines_data["items"]: + logging.warning("No onboarded pipelines found.") + return Enum("Pipeline", {}), {} + + pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) + + pipelines_details = { + pipeline["id"]: { + "id": pipeline["id"], + "name": pipeline.get("name", ""), + "description": pipeline.get("description", ""), + "api_key": pipeline.get("api_key", config.TEAM_API_KEY), + "supplier": pipeline.get("supplier", "aiXplain"), + "version": pipeline.get("version", "1.0"), + "status": pipeline.get("status", "onboarded"), + "created_at": pipeline.get("created_at"), + "architecture": pipeline.get("architecture", {}), + **pipeline, + } + for pipeline in pipelines_data["items"] + } + + return pipelines_enum, pipelines_details + +Pipeline, PipelineDetails = load_pipelines() diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 5a0eb6ae..6717234f 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -2,26 +2,33 @@ import json import time import logging +from filelock import FileLock -CACHE_DURATION = 24 * 60 * 60 -CACHE_FOLDER = ".aixplain_cache" +CACHE_FOLDER = ".cache" +CACHE_FILE = f"{CACHE_FOLDER}/cache.json" +LOCK_FILE = f"{CACHE_FILE}.lock" +DEFAULT_CACHE_EXPIRY = 86400 +def get_cache_expiry(): + return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) -def save_to_cache(cache_file, data): + +def save_to_cache(cache_file, data, lock_file): try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) - with open(cache_file, "w") as f: - json.dump({"timestamp": time.time(), "data": data}, f) + with FileLock(lock_file): + with open(cache_file, "w") as f: + json.dump({"timestamp": time.time(), "data": data}, f) except Exception as e: logging.error(f"Failed to save cache to {cache_file}: {e}") - -def load_from_cache(cache_file): - if os.path.exists(cache_file) is True: - with open(cache_file, "r") as f: - cache_data = json.load(f) - if time.time() - cache_data["timestamp"] < CACHE_DURATION: - return cache_data["data"] - else: - return None +def load_from_cache(cache_file, lock_file): + if os.path.exists(cache_file): + with FileLock(lock_file): + with open(cache_file, "r") as f: + cache_data = json.load(f) + if time.time() - cache_data["timestamp"] < int(get_cache_expiry()): + return cache_data["data"] + else: + return None return None From 6c8e3cf6b515f7085750d56babbdf6d6aa7cdd21 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 6 Feb 2025 00:48:08 +0300 Subject: [PATCH 02/13] formatting --- aixplain/enums/function.py | 3 +- aixplain/enums/language.py | 4 +- aixplain/enums/license.py | 2 +- .../factories/team_agent_factory/__init__.py | 2 +- aixplain/modules/agent/__init__.py | 24 ++++----- aixplain/modules/agent/cache_agents.py | 16 +++--- aixplain/modules/model/__init__.py | 1 + aixplain/modules/model/cache_models.py | 8 +-- aixplain/modules/model/utility_model.py | 16 +++--- aixplain/modules/pipeline/asset.py | 51 ++++++------------- aixplain/modules/pipeline/pipeline_cache.py | 12 +++-- aixplain/utils/cache_utils.py | 2 + aixplain/utils/file_utils.py | 4 +- .../model/run_utility_model_test.py | 27 ++++++---- tests/unit/utility_test.py | 44 ++++++++++++---- tests/unit/utility_tool_decorator_test.py | 38 +++++++------- 16 files changed, 138 insertions(+), 116 deletions(-) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index f8190c9d..a77f3cfc 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -31,7 +31,8 @@ import os CACHE_FILE = f"{CACHE_FOLDER}/functions.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" + def load_functions(): api_key = config.TEAM_API_KEY diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index 52938793..c129822f 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -28,11 +28,11 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE,LOCK_FILE) + resp = load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 22ad9b0d..f9758b84 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -29,7 +29,7 @@ from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index ea145d9a..c9b7e6cc 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -62,7 +62,7 @@ def create( from aixplain.modules.agent import Agent assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class" - + mentalist_and_inspector_llm_id = None if use_inspector or use_mentalist_and_inspector: mentalist_and_inspector_llm_id = llm_id diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 9350a638..441c20fa 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -81,18 +81,18 @@ def __init__( AgentCache, AgentDetails = load_agents(cache_expiry=86400) if id in AgentDetails: - cached_agent= AgentDetails[id] - name=cached_agent["name"] - description=cached_agent["description"] - tools=cached_agent["tools"] - llm_id=cached_agent["llm_id"] - api_key=cached_agent["api_key"] - supplier=cached_agent["supplier"] - version=cached_agent["version"] - cost=cached_agent["cost"] - status=cached_agent["status"] - tasks=cached_agent["tasks"] - additional_info=cached_agent["additional_info"] + cached_agent = AgentDetails[id] + name = cached_agent["name"] + description = cached_agent["description"] + tools = cached_agent["tools"] + llm_id = cached_agent["llm_id"] + api_key = cached_agent["api_key"] + supplier = cached_agent["supplier"] + version = cached_agent["version"] + cost = cached_agent["cost"] + status = cached_agent["status"] + tasks = cached_agent["tasks"] + additional_info = cached_agent["additional_info"] """Create an Agent with the necessary information. diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py index a04bad00..4d38344f 100644 --- a/aixplain/modules/agent/cache_agents.py +++ b/aixplain/modules/agent/cache_agents.py @@ -9,10 +9,11 @@ from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER from aixplain.enums import Supplier -from aixplain.modules.agent.tool import Tool +from aixplain.modules.agent.tool import Tool AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" -LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" +LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" + def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: """ @@ -26,7 +27,7 @@ def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) """ if cache_expiry is None: - cache_expiry = 86400 + cache_expiry = 86400 os.makedirs(CACHE_FOLDER, exist_ok=True) @@ -42,12 +43,11 @@ def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: try: response = _request_with_retry("get", url, headers=headers) - response.raise_for_status() + response.raise_for_status() agents_data = response.json() except Exception as e: logging.error(f"Failed to fetch agents from API: {e}") - return Enum("Agent", {}), {} - + return Enum("Agent", {}), {} onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] @@ -67,7 +67,7 @@ def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: - agents_enum: Enum with agent IDs. - agents_details: Dictionary containing all agent parameters. """ - if not agents_data["items"]: + if not agents_data["items"]: logging.warning("No onboarded agents found.") return Enum("Agent", {}), {} @@ -90,7 +90,7 @@ def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: "status": agent.get("status", "onboarded"), "created_at": agent.get("created_at", ""), "tasks": agent.get("tasks", []), - **agent, + **agent, } for agent in agents_data["items"] } diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 69e02cc6..f3962e01 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -36,6 +36,7 @@ from aixplain.modules.model.model_parameters import ModelParameters from aixplain.modules.model.cache_models import load_models + class Model(Asset): """This is ready-to-use AI model. This model can be run in both synchronous and asynchronous manner. diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py index d5628587..a336c488 100644 --- a/aixplain/modules/model/cache_models.py +++ b/aixplain/modules/model/cache_models.py @@ -12,7 +12,8 @@ from aixplain.enums import Supplier, Function CACHE_FILE = f"{CACHE_FOLDER}/models.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +LOCK_FILE = f"{CACHE_FILE}.lock" + def load_models(cache_expiry: Optional[int] = None): """ @@ -46,6 +47,7 @@ def load_models(cache_expiry: Optional[int] = None): return parse_models({"items": onboarded_models}) + def parse_models(models_data): """ Convert model data into an Enum and dictionary format for easy use. @@ -55,7 +57,7 @@ def parse_models(models_data): - models_details: Dictionary containing all model parameters. """ - if not models_data["items"]: + if not models_data["items"]: logging.warning("No onboarded models found.") return Enum("Model", {}), {} models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) @@ -75,7 +77,7 @@ def parse_models(models_data): "input_params": model.get("input_params"), "output_params": model.get("output_params"), "model_params": model.get("model_params"), - **model, + **model, } for model in models_data["items"] } diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 6a323f45..a0b0ca15 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -43,17 +43,20 @@ def validate(self): def to_dict(self): return {"name": self.name, "description": self.description, "type": self.type.value} + # Tool decorator -def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status = AssetStatus.DRAFT): +def utility_tool( + name: Text, description: Text, inputs: List[UtilityModelInput] = None, output_examples: Text = "", status=AssetStatus.DRAFT +): """Decorator for utility tool functions - + Args: name: Name of the utility tool description: Description of what the utility tool does inputs: List of input parameters, must be UtilityModelInput objects output_examples: Examples of expected outputs status: Asset status - + Raises: ValueError: If name or description is empty TypeError: If inputs contains non-UtilityModelInput objects @@ -63,7 +66,7 @@ def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] raise ValueError("Utility tool name cannot be empty") if not description or not description.strip(): raise ValueError("Utility tool description cannot be empty") - + # Validate inputs if inputs is not None: if not isinstance(inputs, list): @@ -71,7 +74,7 @@ def utility_tool(name: Text, description: Text, inputs: List[UtilityModelInput] for input_param in inputs: if not isinstance(input_param, UtilityModelInput): raise TypeError(f"Invalid input parameter: {input_param}. All inputs must be UtilityModelInput objects") - + def decorator(func): func._is_utility_tool = True # Mark function as utility tool func._tool_name = name.strip() @@ -80,6 +83,7 @@ def decorator(func): func._tool_output_examples = output_examples func._tool_status = status return func + return decorator @@ -116,7 +120,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - status: AssetStatus = AssetStatus.ONBOARDED,# TODO: change to draft when we have the backend ready + status: AssetStatus = AssetStatus.ONBOARDED, # TODO: change to draft when we have the backend ready **additional_info, ) -> None: """Utility Model Init diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 0cf5ba63..216274fc 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -81,7 +81,7 @@ def __init__( version = cached_pipeline["version"] status = cached_pipeline["status"] additional_info = cached_pipeline["architecture"] - + if not name: raise ValueError("Pipeline name is required") @@ -122,9 +122,7 @@ def __polling( while not completed and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) - logging.debug( - f"Polling for Pipeline: Status of polling for {name} : {response_body}" - ) + logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") completed = response_body["completed"] end = time.time() @@ -136,13 +134,9 @@ def __polling( logging.error(f"Polling for Pipeline: polling for {name} : Continue") if response_body and response_body["status"] == "SUCCESS": try: - logging.debug( - f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}" - ) + logging.debug(f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}") except Exception: - logging.error( - f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}" - ) + logging.error(f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}") else: logging.error( f"Polling for Pipeline: Final status of polling for {name} : No response in {timeout} seconds - {response_body}" @@ -172,9 +166,7 @@ def poll(self, poll_url: Text, name: Text = "pipeline_process") -> Dict: resp["data"] = json.loads(resp["data"])["response"] except Exception: resp = r.json() - logging.info( - f"Single Poll for Pipeline: Status of polling for {name} : {resp}" - ) + logging.info(f"Single Poll for Pipeline: Status of polling for {name} : {resp}") except Exception: resp = {"status": "FAILED"} return resp @@ -196,9 +188,7 @@ def _should_fallback_to_v2(self, response: Dict, version: str) -> bool: should_fallback = False if "status" not in response or response["status"] == "FAILED": should_fallback = True - elif response["status"] == "SUCCESS" and ( - "data" not in response or not response["data"] - ): + elif response["status"] == "SUCCESS" and ("data" not in response or not response["data"]): should_fallback = True # Check for conditions that require a fallback @@ -304,10 +294,7 @@ def __prepare_payload( try: payload = json.loads(data) if isinstance(payload, dict) is False: - if ( - isinstance(payload, int) is True - or isinstance(payload, float) is True - ): + if isinstance(payload, int) is True or isinstance(payload, float) is True: payload = str(payload) payload = {"data": payload} except Exception: @@ -345,9 +332,7 @@ def __prepare_payload( asset_payload["dataAsset"]["dataset_id"] = dasset.id source_data_list = [ - dfield - for dfield in dasset.source_data - if dasset.source_data[dfield].id == data[node_label] + dfield for dfield in dasset.source_data if dasset.source_data[dfield].id == data[node_label] ] if len(source_data_list) > 0: @@ -420,9 +405,7 @@ def run_async( try: if 200 <= r.status_code < 300: resp = r.json() - logging.info( - f"Result of request for {name} - {r.status_code} - {resp}" - ) + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") poll_url = resp["url"] response = {"status": "IN_PROGRESS", "url": poll_url} else: @@ -438,7 +421,9 @@ def run_async( error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." else: status_code = str(r.status_code) - error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + error = ( + f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + ) response = {"status": "FAILED", "error_message": error} logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: @@ -487,9 +472,7 @@ def update( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -507,9 +490,7 @@ def update( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("put", url, headers=headers, json=payload) response = r.json() logging.info(f"Pipeline {response['id']} Updated.") @@ -564,9 +545,7 @@ def save( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py index 4c2f817e..2d90e069 100644 --- a/aixplain/modules/pipeline/pipeline_cache.py +++ b/aixplain/modules/pipeline/pipeline_cache.py @@ -13,6 +13,8 @@ PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" + + def load_pipelines(cache_expiry: Optional[int] = None): """ Load pipelines from cache or fetch from backend if not cached. @@ -22,7 +24,7 @@ def load_pipelines(cache_expiry: Optional[int] = None): cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. """ if cache_expiry is None: - cache_expiry = 86400 + cache_expiry = 86400 api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -46,6 +48,7 @@ def load_pipelines(cache_expiry: Optional[int] = None): return parse_pipelines({"items": onboarded_pipelines}) + def parse_pipelines(pipelines_data): """ Convert pipeline data into an Enum and dictionary format for easy use. @@ -54,10 +57,10 @@ def parse_pipelines(pipelines_data): - pipelines_enum: Enum with pipeline IDs. - pipelines_details: Dictionary containing all pipeline parameters. """ - if not pipelines_data["items"]: + if not pipelines_data["items"]: logging.warning("No onboarded pipelines found.") return Enum("Pipeline", {}), {} - + pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) pipelines_details = { @@ -71,11 +74,12 @@ def parse_pipelines(pipelines_data): "status": pipeline.get("status", "onboarded"), "created_at": pipeline.get("created_at"), "architecture": pipeline.get("architecture", {}), - **pipeline, + **pipeline, } for pipeline in pipelines_data["items"] } return pipelines_enum, pipelines_details + Pipeline, PipelineDetails = load_pipelines() diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 6717234f..01981701 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -9,6 +9,7 @@ LOCK_FILE = f"{CACHE_FILE}.lock" DEFAULT_CACHE_EXPIRY = 86400 + def get_cache_expiry(): return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) @@ -22,6 +23,7 @@ def save_to_cache(cache_file, data, lock_file): except Exception as e: logging.error(f"Failed to save cache to {cache_file}: {e}") + def load_from_cache(cache_file, lock_file): if os.path.exists(cache_file): with FileLock(lock_file): diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index d39ca2b9..554b80d9 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -153,7 +153,9 @@ def upload_data( raise Exception("File Uploading Error: Failure on Uploading to S3.") -def s3_to_csv(s3_url: Text, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None}) -> Text: +def s3_to_csv( + s3_url: Text, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None} +) -> Text: """Convert s3 url to a csv file and download the file in `download_path` Args: diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index b9ef5465..17257b2a 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -2,6 +2,7 @@ from aixplain.modules.model.utility_model import UtilityModelInput, utility_tool from aixplain.enums import DataType + def test_run_utility_model(): utility_model = None try: @@ -36,22 +37,24 @@ def test_run_utility_model(): if utility_model: utility_model.delete() + def test_utility_model_with_decorator(): utility_model = None try: + @utility_tool( - name="add_numbers_test name", - description="Adds two numbers together.", - inputs=[ + name="add_numbers_test name", + description="Adds two numbers together.", + inputs=[ UtilityModelInput(name="num1", type=DataType.NUMBER, description="The first number."), - UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number.") + UtilityModelInput(name="num2", type=DataType.NUMBER, description="The second number."), ], ) def add_numbers(num1: int, num2: int) -> int: return num1 + num2 utility_model = ModelFactory.create_utility_model(code=add_numbers) - + assert utility_model.id is not None assert len(utility_model.inputs) == 2 assert utility_model.inputs[0].name == "num1" @@ -64,16 +67,18 @@ def add_numbers(num1: int, num2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_string_concatenation(): utility_model = None try: + @utility_tool( name="concatenate_strings", description="Concatenates two strings.", inputs=[ UtilityModelInput(name="str1", type=DataType.TEXT, description="The first string."), UtilityModelInput(name="str2", type=DataType.TEXT, description="The second string."), - ] + ], ) def concatenate_strings(str1: str, str2: str) -> str: """Concatenates two strings and returns the result.""" @@ -96,6 +101,7 @@ def concatenate_strings(str1: str, str2: str) -> str: if utility_model: utility_model.delete() + def test_utility_model_code_as_string(): utility_model = None try: @@ -108,10 +114,7 @@ def multiply_numbers(int1: int, int2: int) -> int: \"\"\"Multiply two numbers and returns the result.\"\"\" return int1 * int2 """ - utility_model = ModelFactory.create_utility_model( - name="Multiply Numbers Test", - code=code - ) + utility_model = ModelFactory.create_utility_model(name="Multiply Numbers Test", code=code) assert utility_model.id is not None assert len(utility_model.inputs) == 2 @@ -123,13 +126,15 @@ def multiply_numbers(int1: int, int2: int) -> int: if utility_model: utility_model.delete() + def test_utility_model_simple_function(): utility_model = None try: + def test_string(input: str): """test string""" return input - + utility_model = ModelFactory.create_utility_model( name="String Model Test", code=test_string, diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 305c6a52..8adfe5bc 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -25,7 +25,9 @@ def test_utility_model(): assert utility_model.name == "utility_model_test" assert utility_model.description == "utility_model_test" assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT)] + assert utility_model.inputs == [ + UtilityModelInput(name="input_string", description="The input_string input is a text", type=DataType.TEXT) + ] assert utility_model.output_examples == "output_description" @@ -87,8 +89,14 @@ def test_utility_model_to_dict(): def test_update_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -127,8 +135,14 @@ def test_update_utility_model(): def test_save_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): with patch( "aixplain.modules.model.utils.parse_code", return_value=( @@ -170,8 +184,14 @@ def test_save_utility_model(): def test_delete_utility_model(): with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): mock.delete(urljoin(config.BACKEND_URL, "sdk/utilities/123"), status_code=200, json={"id": "123"}) utility_model = UtilityModel( id="123", @@ -241,8 +261,14 @@ def main(originCode): def test_validate_new_model(): """Test validation for a new model""" - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n'): + with patch( + "aixplain.factories.file_factory.FileFactory.to_link", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): + with patch( + "aixplain.factories.file_factory.FileFactory.upload", + return_value='def main(input_string:str):\n """\n Get driving directions from start_location to end_location\n """\n return f"This is the output for input: {input_string}"\n', + ): # Test with valid inputs utility_model = UtilityModel( id="", # Empty ID for new model diff --git a/tests/unit/utility_tool_decorator_test.py b/tests/unit/utility_tool_decorator_test.py index f9c87f02..c63aa5f6 100644 --- a/tests/unit/utility_tool_decorator_test.py +++ b/tests/unit/utility_tool_decorator_test.py @@ -3,16 +3,15 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput + def test_utility_tool_basic_decoration(): """Test basic decoration with minimal parameters""" - @utility_tool( - name="test_function", - description="Test function description" - ) + + @utility_tool(name="test_function", description="Test function description") def test_func(input_text: str) -> str: return input_text - assert hasattr(test_func, '_is_utility_tool') + assert hasattr(test_func, "_is_utility_tool") assert test_func._is_utility_tool is True assert test_func._tool_name == "test_function" assert test_func._tool_description == "Test function description" @@ -20,19 +19,20 @@ def test_func(input_text: str) -> str: assert test_func._tool_output_examples == "" assert test_func._tool_status == AssetStatus.DRAFT + def test_utility_tool_with_all_parameters(): """Test decoration with all optional parameters""" inputs = [ UtilityModelInput(name="text_input", type=DataType.TEXT, description="A text input"), - UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input") + UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input"), ] - + @utility_tool( name="full_test_function", description="Full test function description", inputs=inputs, output_examples="Example output: Hello World", - status=AssetStatus.ONBOARDED + status=AssetStatus.ONBOARDED, ) def test_func(text_input: str, num_input: int) -> str: return f"{text_input} {num_input}" @@ -45,32 +45,28 @@ def test_func(text_input: str, num_input: int) -> str: assert test_func._tool_output_examples == "Example output: Hello World" assert test_func._tool_status == AssetStatus.ONBOARDED + def test_utility_tool_function_still_callable(): """Test that decorated function remains callable""" - @utility_tool( - name="callable_test", - description="Test function callable" - ) + + @utility_tool(name="callable_test", description="Test function callable") def test_func(x: int, y: int) -> int: return x + y assert test_func(2, 3) == 5 assert test_func._is_utility_tool is True + def test_utility_tool_invalid_inputs(): """Test validation of invalid inputs""" with pytest.raises(ValueError): - @utility_tool( - name="", # Empty name should raise error - description="Test description" - ) + + @utility_tool(name="", description="Test description") # Empty name should raise error def test_func(): pass with pytest.raises(ValueError): - @utility_tool( - name="test_name", - description="" # Empty description should raise error - ) + + @utility_tool(name="test_name", description="") # Empty description should raise error def test_func(): - pass \ No newline at end of file + pass From 7872aa40615dde5468a46858f59b18cc8e121843 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 7 Feb 2025 15:31:09 +0300 Subject: [PATCH 03/13] Agent Cache Class added --- aixplain/enums/__init__.py | 1 + aixplain/enums/aixplain_cache.py | 114 ++++++++++++++++++++ aixplain/modules/agent/__init__.py | 6 +- aixplain/modules/agent/cache_agents.py | 101 ----------------- aixplain/modules/model/__init__.py | 6 +- aixplain/modules/model/cache_models.py | 88 --------------- aixplain/modules/pipeline/asset.py | 4 +- aixplain/modules/pipeline/pipeline_cache.py | 85 --------------- 8 files changed, 124 insertions(+), 281 deletions(-) create mode 100644 aixplain/enums/aixplain_cache.py delete mode 100644 aixplain/modules/agent/cache_agents.py delete mode 100644 aixplain/modules/model/cache_models.py delete mode 100644 aixplain/modules/pipeline/pipeline_cache.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index ef497ddd..201dd9e7 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -14,3 +14,4 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus +from .aixplain_cache import AixplainCache \ No newline at end of file diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py new file mode 100644 index 00000000..339360a3 --- /dev/null +++ b/aixplain/enums/aixplain_cache.py @@ -0,0 +1,114 @@ +import os +import json +import logging +from datetime import datetime +from enum import Enum +from urllib.parse import urljoin +from typing import Dict, Optional, Tuple, List +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER + + +class AixplainCache: + """ + A modular caching system to handle different asset types (Models, Pipelines, Agents). + This class reduces code repetition and allows easier maintenance. + """ + + def __init__(self, asset_type: str, cache_filename: str): + """ + Initialize the cache for a given asset type. + + Args: + asset_type (str): Type of the asset (e.g., "models", "pipelines", "agents"). + cache_filename (str): Filename for storing cached data. + """ + self.asset_type = asset_type + self.cache_file = f"{CACHE_FOLDER}/{cache_filename}.json" + self.lock_file = f"{self.cache_file}.lock" + os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists + + def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: + """ + Load assets from cache or fetch from backend if not cached. + + Args: + cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. + + Returns: + Tuple[Enum, Dict]: (Enum of asset IDs, Dictionary with asset details) + """ + cached_data = load_from_cache(self.cache_file, self.lock_file) + if cached_data is not None: + return self.parse_assets(cached_data) + + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + url = urljoin(backend_url, f"sdk/{self.asset_type}") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + try: + response = _request_with_retry("get", url, headers=headers) + response.raise_for_status() + assets_data = response.json() + except Exception as e: + logging.error(f"Failed to fetch {self.asset_type} from API: {e}") + return Enum(self.asset_type.capitalize(), {}), {} + + if "items" not in assets_data: + return Enum(self.asset_type.capitalize(), {}), {} + + onboarded_assets = [asset for asset in assets_data["items"] if asset.get("status", "").lower() == "onboarded"] + + save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) + + return self.parse_assets({"items": onboarded_assets}) + + def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: + """ + Convert asset data into an Enum and dictionary format for easy use. + + Args: + assets_data (Dict): JSON response with asset list. + + Returns: + - assets_enum: Enum with asset IDs. + - assets_details: Dictionary containing all asset parameters. + """ + if not assets_data["items"]: # Handle case where no assets are onboarded + logging.warning(f"No onboarded {self.asset_type} found.") + return Enum(self.asset_type.capitalize(), {}), {} + + assets_enum = Enum( + self.asset_type.capitalize(), + {a["id"].upper().replace("-", "_"): a["id"] for a in assets_data["items"]}, + type=str, + ) + + assets_details = { + asset["id"]: { + "id": asset["id"], + "name": asset.get("name", ""), + "description": asset.get("description", ""), + "api_key": asset.get("api_key", config.TEAM_API_KEY), + "supplier": asset.get("supplier", "aiXplain"), + "version": asset.get("version", "1.0"), + "status": asset.get("status", "onboarded"), + "created_at": asset.get("created_at", ""), + **asset, # Include any extra fields + } + for asset in assets_data["items"] + } + + return assets_enum, assets_details + + +ModelCache = AixplainCache("models", "models") +Model, ModelDetails = ModelCache.load_assets() + +PipelineCache = AixplainCache("pipelines", "pipelines") +Pipeline, PipelineDetails = PipelineCache.load_assets() + +AgentCache = AixplainCache("agents", "agents") +Agent, AgentDetails = AgentCache.load_assets() diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 441c20fa..26052e06 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -37,9 +37,8 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus +from aixplain.enums import ResponseStatus, AixplainCache from aixplain.modules.agent.utils import process_variables -from aixplain.modules.agent.cache_agents import load_agents from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -78,7 +77,8 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: - AgentCache, AgentDetails = load_agents(cache_expiry=86400) + AgentCache = AixplainCache("agents", "agents") + AgentEnum, AgentDetails = AgentCache.load_assets() if id in AgentDetails: cached_agent = AgentDetails[id] diff --git a/aixplain/modules/agent/cache_agents.py b/aixplain/modules/agent/cache_agents.py deleted file mode 100644 index 4d38344f..00000000 --- a/aixplain/modules/agent/cache_agents.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import json -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, List, Tuple, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier -from aixplain.modules.agent.tool import Tool - -AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json" -LOCK_FILE = f"{AGENT_CACHE_FILE}.lock" - - -def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]: - """ - Load AI agents from cache or fetch from backend if not cached. - Only agents with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. - - Returns: - Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details) - """ - if cache_expiry is None: - cache_expiry = 86400 - - os.makedirs(CACHE_FOLDER, exist_ok=True) - - cached_data = load_from_cache(AGENT_CACHE_FILE, LOCK_FILE) - - if cached_data is not None: - return parse_agents(cached_data) - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - url = urljoin(backend_url, "sdk/agents") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - try: - response = _request_with_retry("get", url, headers=headers) - response.raise_for_status() - agents_data = response.json() - except Exception as e: - logging.error(f"Failed to fetch agents from API: {e}") - return Enum("Agent", {}), {} - - onboarded_agents = [agent for agent in agents_data if agent.get("status", "").lower() == "onboarded"] - - save_to_cache(AGENT_CACHE_FILE, {"items": onboarded_agents}, LOCK_FILE) - - return parse_agents({"items": onboarded_agents}) - - -def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]: - """ - Convert agent data into an Enum and dictionary format for easy use. - - Args: - agents_data (Dict): JSON response with agents list. - - Returns: - - agents_enum: Enum with agent IDs. - - agents_details: Dictionary containing all agent parameters. - """ - if not agents_data["items"]: - logging.warning("No onboarded agents found.") - return Enum("Agent", {}), {} - - agents_enum = Enum( - "Agent", - {a["id"].upper().replace("-", "_"): a["id"] for a in agents_data["items"]}, - type=str, - ) - - agents_details = { - agent["id"]: { - "id": agent["id"], - "name": agent.get("name", ""), - "description": agent.get("description", ""), - "role": agent.get("role", ""), - "tools": [Tool(t) if isinstance(t, dict) else t for t in agent.get("tools", [])], - "llm_id": agent.get("llm_id", "6646261c6eb563165658bbb1"), - "supplier": agent.get("supplier", "aiXplain"), - "version": agent.get("version", "1.0"), - "status": agent.get("status", "onboarded"), - "created_at": agent.get("created_at", ""), - "tasks": agent.get("tasks", []), - **agent, - } - for agent in agents_data["items"] - } - - return agents_enum, agents_details - - -Agent, AgentDetails = load_agents() diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index f3962e01..3aee0a31 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -23,7 +23,7 @@ import time import logging import traceback -from aixplain.enums import Supplier, Function +from aixplain.enums import Supplier, Function, AixplainCache from aixplain.modules.asset import Asset from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config @@ -34,7 +34,6 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters -from aixplain.modules.model.cache_models import load_models class Model(Asset): @@ -92,7 +91,8 @@ def __init__( model_params (Dict, optional): parameters for the function. **additional_info: Any additional Model info to be saved """ - ModelCache, ModelDetails = load_models(cache_expiry=86400) + ModelCache = AixplainCache("models", "models") + ModelEnum, ModelDetails = ModelCache.load_assets() if id in ModelDetails: diff --git a/aixplain/modules/model/cache_models.py b/aixplain/modules/model/cache_models.py deleted file mode 100644 index a336c488..00000000 --- a/aixplain/modules/model/cache_models.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import json -import time -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier, Function - -CACHE_FILE = f"{CACHE_FOLDER}/models.json" -LOCK_FILE = f"{CACHE_FILE}.lock" - - -def load_models(cache_expiry: Optional[int] = None): - """ - Load models from cache or fetch from backend if not cached. - Only models with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. - """ - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - - cached_data = load_from_cache(CACHE_FILE, LOCK_FILE) - - if cached_data is not None: - - return parse_models(cached_data) - - url = urljoin(backend_url, "sdk/models") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - response = _request_with_retry("get", url, headers=headers) - if not 200 <= response.status_code < 300: - raise Exception(f"Models could not be loaded, API key '{api_key}' might be invalid.") - - models_data = response.json() - - onboarded_models = [model for model in models_data["items"] if model["status"].lower() == "onboarded"] - save_to_cache(CACHE_FILE, {"items": onboarded_models}, LOCK_FILE) - - return parse_models({"items": onboarded_models}) - - -def parse_models(models_data): - """ - Convert model data into an Enum and dictionary format for easy use. - - Returns: - - models_enum: Enum with model IDs. - - models_details: Dictionary containing all model parameters. - """ - - if not models_data["items"]: - logging.warning("No onboarded models found.") - return Enum("Model", {}), {} - models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in models_data["items"]}, type=str) - - models_details = { - model["id"]: { - "id": model["id"], - "name": model.get("name", ""), - "description": model.get("description", ""), - "api_key": model.get("api_key", config.TEAM_API_KEY), - "supplier": model.get("supplier", "aiXplain"), - "version": model.get("version"), - "function": model.get("function"), - "is_subscribed": model.get("is_subscribed", False), - "cost": model.get("cost"), - "created_at": model.get("created_at"), - "input_params": model.get("input_params"), - "output_params": model.get("output_params"), - "model_params": model.get("model_params"), - **model, - } - for model in models_data["items"] - } - - return models_enum, models_details - - -Model, ModelDetails = load_models() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 216274fc..ee2c421e 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -31,7 +31,7 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin -from aixplain.modules.pipeline.pipeline_cache import PipelineDetails +from aixplain.enums import AixplainCache class Pipeline(Asset): @@ -61,6 +61,8 @@ def __init__( status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: + PipelineCache = AixplainCache("pipelines", "pipelines") + PipelineEnum, PipelineDetails = PipelineCache.load_assets() """Create a Pipeline with the necessary information Args: diff --git a/aixplain/modules/pipeline/pipeline_cache.py b/aixplain/modules/pipeline/pipeline_cache.py deleted file mode 100644 index 2d90e069..00000000 --- a/aixplain/modules/pipeline/pipeline_cache.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import json -import time -import logging -from datetime import datetime -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, Union, Text -from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER -from aixplain.enums import Supplier - -PIPELINE_CACHE_FILE = f"{CACHE_FOLDER}/pipelines.json" -LOCK_FILE = f"{PIPELINE_CACHE_FILE}.lock" - - -def load_pipelines(cache_expiry: Optional[int] = None): - """ - Load pipelines from cache or fetch from backend if not cached. - Only pipelines with status "onboarded" should be cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is user-configurable. - """ - if cache_expiry is None: - cache_expiry = 86400 - - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - - cached_data = load_from_cache(PIPELINE_CACHE_FILE, LOCK_FILE) - if cached_data is not None: - return parse_pipelines(cached_data) - - url = urljoin(backend_url, "sdk/pipelines") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - response = _request_with_retry("get", url, headers=headers) - if not 200 <= response.status_code < 300: - raise Exception(f"Pipelines could not be loaded, API key '{api_key}' might be invalid.") - - pipelines_data = response.json() - - onboarded_pipelines = [pipeline for pipeline in pipelines_data["items"] if pipeline["status"].lower() == "onboarded"] - - save_to_cache(PIPELINE_CACHE_FILE, {"items": onboarded_pipelines}, LOCK_FILE) - - return parse_pipelines({"items": onboarded_pipelines}) - - -def parse_pipelines(pipelines_data): - """ - Convert pipeline data into an Enum and dictionary format for easy use. - - Returns: - - pipelines_enum: Enum with pipeline IDs. - - pipelines_details: Dictionary containing all pipeline parameters. - """ - if not pipelines_data["items"]: - logging.warning("No onboarded pipelines found.") - return Enum("Pipeline", {}), {} - - pipelines_enum = Enum("Pipeline", {p["id"].upper().replace("-", "_"): p["id"] for p in pipelines_data["items"]}, type=str) - - pipelines_details = { - pipeline["id"]: { - "id": pipeline["id"], - "name": pipeline.get("name", ""), - "description": pipeline.get("description", ""), - "api_key": pipeline.get("api_key", config.TEAM_API_KEY), - "supplier": pipeline.get("supplier", "aiXplain"), - "version": pipeline.get("version", "1.0"), - "status": pipeline.get("status", "onboarded"), - "created_at": pipeline.get("created_at"), - "architecture": pipeline.get("architecture", {}), - **pipeline, - } - for pipeline in pipelines_data["items"] - } - - return pipelines_enum, pipelines_details - - -Pipeline, PipelineDetails = load_pipelines() From 92d8e283ca590f2625877a018f863bf351ca5a7f Mon Sep 17 00:00:00 2001 From: ahmetgunduz Date: Mon, 10 Feb 2025 10:49:28 +0000 Subject: [PATCH 04/13] removed unused imports --- aixplain/enums/aixplain_cache.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 339360a3..23129e24 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -1,10 +1,8 @@ import os -import json import logging -from datetime import datetime from enum import Enum from urllib.parse import urljoin -from typing import Dict, Optional, Tuple, List +from typing import Dict, Optional, Tuple from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER @@ -54,7 +52,7 @@ def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: assets_data = response.json() except Exception as e: logging.error(f"Failed to fetch {self.asset_type} from API: {e}") - return Enum(self.asset_type.capitalize(), {}), {} + return Enum(self.asset_type.capitalize(), {}), {} if "items" not in assets_data: return Enum(self.asset_type.capitalize(), {}), {} @@ -104,11 +102,16 @@ def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: return assets_enum, assets_details -ModelCache = AixplainCache("models", "models") -Model, ModelDetails = ModelCache.load_assets() +if __name__ == "__main__": + ModelCache = AixplainCache("models", "models") + Model, ModelDetails = ModelCache.load_assets() + + PipelineCache = AixplainCache("pipelines", "pipelines") + Pipeline, PipelineDetails = PipelineCache.load_assets() -PipelineCache = AixplainCache("pipelines", "pipelines") -Pipeline, PipelineDetails = PipelineCache.load_assets() + AgentCache = AixplainCache("agents", "agents") + Agent, AgentDetails = AgentCache.load_assets() -AgentCache = AixplainCache("agents", "agents") -Agent, AgentDetails = AgentCache.load_assets() + print(ModelDetails) + print(PipelineDetails) + print(AgentDetails) From 31d7b7954fdfa56a598041c5530d63185116bdc5 Mon Sep 17 00:00:00 2001 From: xainaz Date: Wed, 19 Mar 2025 14:23:29 +0300 Subject: [PATCH 05/13] made requested changes and added functional tests --- aixplain/enums/aixplain_cache.py | 22 +++------------- aixplain/modules/agent/__init__.py | 32 +++++++++++------------ aixplain/modules/model/__init__.py | 23 +++++++++------- tests/functional/model/run_model_test.py | 29 ++++++++++++++++++-- tests/functional/pipelines/create_test.py | 27 ++++++++++++++++++- 5 files changed, 86 insertions(+), 47 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 23129e24..26543693 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -14,7 +14,7 @@ class AixplainCache: This class reduces code repetition and allows easier maintenance. """ - def __init__(self, asset_type: str, cache_filename: str): + def __init__(self, asset_type: str, cache_filename: Optional[str] = None): """ Initialize the cache for a given asset type. @@ -23,7 +23,8 @@ def __init__(self, asset_type: str, cache_filename: str): cache_filename (str): Filename for storing cached data. """ self.asset_type = asset_type - self.cache_file = f"{CACHE_FOLDER}/{cache_filename}.json" + filename = cache_filename if cache_filename else asset_type + self.cache_file = f"{CACHE_FOLDER}/{filename}.json" self.lock_file = f"{self.cache_file}.lock" os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists @@ -99,19 +100,4 @@ def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: for asset in assets_data["items"] } - return assets_enum, assets_details - - -if __name__ == "__main__": - ModelCache = AixplainCache("models", "models") - Model, ModelDetails = ModelCache.load_assets() - - PipelineCache = AixplainCache("pipelines", "pipelines") - Pipeline, PipelineDetails = PipelineCache.load_assets() - - AgentCache = AixplainCache("agents", "agents") - Agent, AgentDetails = AgentCache.load_assets() - - print(ModelDetails) - print(PipelineDetails) - print(AgentDetails) + return assets_enum, assets_details \ No newline at end of file diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 372c78c3..4bf54565 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -61,6 +61,8 @@ class Agent(Model): api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. """ + AgentCache = AixplainCache("agents", "agents") + AgentEnum, AgentDetails = AgentCache.load_assets() is_valid: bool @@ -80,22 +82,20 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: - AgentCache = AixplainCache("agents", "agents") - AgentEnum, AgentDetails = AgentCache.load_assets() - - if id in AgentDetails: - cached_agent = AgentDetails[id] - name = cached_agent["name"] - description = cached_agent["description"] - tools = cached_agent["tools"] - llm_id = cached_agent["llm_id"] - api_key = cached_agent["api_key"] - supplier = cached_agent["supplier"] - version = cached_agent["version"] - cost = cached_agent["cost"] - status = cached_agent["status"] - tasks = cached_agent["tasks"] - additional_info = cached_agent["additional_info"] + + if id in self.__class__.AgentDetails: + cached_agent = self.__class__.AgentDetails[id] + name = cached_agent.get("name", name) + description = cached_agent.get("description", description) + tools = cached_agent.get("tools", tools) + llm_id = cached_agent.get("llm_id", llm_id) + api_key = cached_agent.get("api_key", api_key) + supplier = cached_agent.get("supplier", supplier) + version = cached_agent.get("version", version) + cost = cached_agent.get("cost", cost) + status = cached_agent.get("status", status) + tasks = cached_agent.get("tasks", tasks) + additional_info = cached_agent.get("additional_info", additional_info) """Create an Agent with the necessary information. diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 3aee0a31..e7e3def9 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -95,21 +95,24 @@ def __init__( ModelEnum, ModelDetails = ModelCache.load_assets() if id in ModelDetails: - cached_model = ModelDetails[id] - input_params = cached_model["input"] - api_key = cached_model["spec"].get("api_key", api_key) - additional_info = cached_model["spec"].get("additional_info", {}) - function = cached_model["spec"].get("function", function) - is_subscribed = cached_model["spec"].get("is_subscribed", is_subscribed) - created_at = cached_model["spec"].get("created_at", created_at) - model_params = cached_model["spec"].get("model_params", model_params) - output_params = cached_model["output"] - description = cached_model["spec"].get("description", description) + + input_params = cached_model.get("params", input_params) + function = cached_model.get("function", {}).get("name", function) + name = cached_model.get("name", name) + supplier = cached_model.get("supplier", supplier) + + created_at_str = cached_model.get("createdAt") + if created_at_str: + created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + + cost = cached_model.get("pricing", cost) + super().__init__(id, name, description, supplier, version, cost=cost) self.api_key = api_key self.additional_info = additional_info + self.name = name self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.function = function diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index d3d0082f..6c601917 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -1,12 +1,14 @@ __author__ = "thiagocastroferreira" - +import os +import json from aixplain.enums import Function from aixplain.factories import ModelFactory from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path - +from aixplain.utils.cache_utils import CACHE_FOLDER +from aixplain.modules.model import Model def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: @@ -94,3 +96,26 @@ def test_llm_run_with_file(): # Verify response assert response["status"] == "SUCCESS" assert "🤖" in response["data"], "Robot emoji should be present in the response" + + +def test_aixplain_model_cache_creation(): + """Ensure AixplainCache is triggered and cache is created.""" + + cache_file = os.path.join(CACHE_FOLDER, "models.json") + + # Clean up cache before the test + if os.path.exists(cache_file): + os.remove(cache_file) + + # Instantiate the Model (replace this with a real model ID from your env) + model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) + _ = Model(id=model_id) + + # Assert the cache file was created + assert os.path.exists(cache_file), "Expected cache file was not created." + + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." + assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index ae33f454..8bcb3b7d 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -15,7 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - +import os +from aixplain.utils.cache_utils import CACHE_FOLDER +from aixplain.modules.pipeline import Pipeline import json import pytest from aixplain.factories import PipelineFactory @@ -75,3 +77,26 @@ def test_create_pipeline_wrong_path(PipelineFactory): with pytest.raises(Exception): PipelineFactory.create(name=pipeline_name, pipeline="/") + + + +@pytest.mark.parametrize("PipelineFactory", [PipelineFactory]) +def test_pipeline_cache_creation(PipelineFactory): + cache_file = os.path.join(CACHE_FOLDER, "pipelines.json") + if os.path.exists(cache_file): + os.remove(cache_file) + + pipeline_json = "tests/functional/pipelines/data/pipeline.json" + pipeline_name = str(uuid4()) + pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) + + assert os.path.exists(cache_file), "Pipeline cache file was not created!" + + with open(cache_file, "r") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache format invalid, missing 'data'." + + pipeline.delete() + if os.path.exists(cache_file): + os.remove(cache_file) \ No newline at end of file From be58677b07c6988bc72d848d6bf314205d67ac57 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 24 Mar 2025 22:40:40 +0300 Subject: [PATCH 06/13] changes --- aixplain/enums/aixplain_cache.py | 168 ++++++++++++++++++----------- aixplain/enums/function.py | 6 +- aixplain/enums/language.py | 6 +- aixplain/enums/license.py | 6 +- aixplain/modules/model/__init__.py | 72 +++++++++---- 5 files changed, 170 insertions(+), 88 deletions(-) diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/aixplain_cache.py index 26543693..f8cb2f78 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/aixplain_cache.py @@ -1,103 +1,151 @@ import os import logging +import json +import time from enum import Enum from urllib.parse import urljoin from typing import Dict, Optional, Tuple +from dataclasses import dataclass +from filelock import FileLock + from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER - +from aixplain.enums.privacy import Privacy +# Constants +CACHE_FOLDER = ".cache" +DEFAULT_CACHE_EXPIRY = 86400 + +@dataclass +class Asset: + id: str + name: str = "" + description: str = "" + api_key: str = config.TEAM_API_KEY + supplier: str = "aiXplain" + version: str = "1.0" + status: str = "onboarded" + created_at: str = "" + +class AssetType(Enum): + MODELS = "models" + PIPELINES = "pipelines" + AGENTS = "agents" + +def get_cache_expiry(): + return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + +def _serialize(obj): + if isinstance(obj, (Privacy)): + return str(obj) # or obj.to_dict() if you have it + if isinstance(obj, Enum): + return obj.value + return obj.__dict__ if hasattr(obj, "__dict__") else str(obj) class AixplainCache: """ A modular caching system to handle different asset types (Models, Pipelines, Agents). - This class reduces code repetition and allows easier maintenance. """ - def __init__(self, asset_type: str, cache_filename: Optional[str] = None): - """ - Initialize the cache for a given asset type. - - Args: - asset_type (str): Type of the asset (e.g., "models", "pipelines", "agents"). - cache_filename (str): Filename for storing cached data. - """ - self.asset_type = asset_type - filename = cache_filename if cache_filename else asset_type - self.cache_file = f"{CACHE_FOLDER}/{filename}.json" - self.lock_file = f"{self.cache_file}.lock" - os.makedirs(CACHE_FOLDER, exist_ok=True) # Ensure cache folder exists - - def load_assets(self, cache_expiry: Optional[int] = 86400) -> Tuple[Enum, Dict]: + @staticmethod + def save_to_cache(cache_file, data, lock_file): + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + with FileLock(lock_file): + with open(cache_file, "w") as f: + json.dump({"timestamp": time.time(), "data": data}, f,default=_serialize) + + @staticmethod + def load_from_cache(cache_file, lock_file): + if os.path.exists(cache_file): + with FileLock(lock_file): + with open(cache_file, "r") as f: + cache_data = json.load(f) + if time.time() - cache_data["timestamp"] < get_cache_expiry(): + return cache_data["data"] + else: + try: + os.remove(cache_file) + if os.path.exists(lock_file): + os.remove(lock_file) + except Exception as e: + logging.warning(f"Failed to remove expired cache or lock file: {e}") + return None + + @staticmethod + def fetch_assets_from_backend(asset_type_str: str) -> Optional[Dict]: """ - Load assets from cache or fetch from backend if not cached. - - Args: - cache_expiry (int, optional): Expiry time in seconds. Default is 24 hours. - - Returns: - Tuple[Enum, Dict]: (Enum of asset IDs, Dictionary with asset details) + Fetch assets data from the backend API. """ - cached_data = load_from_cache(self.cache_file, self.lock_file) - if cached_data is not None: - return self.parse_assets(cached_data) - api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL - url = urljoin(backend_url, f"sdk/{self.asset_type}") + url = urljoin(backend_url, f"sdk/{asset_type_str}") headers = {"x-api-key": api_key, "Content-Type": "application/json"} try: response = _request_with_retry("get", url, headers=headers) response.raise_for_status() - assets_data = response.json() + return response.json() except Exception as e: - logging.error(f"Failed to fetch {self.asset_type} from API: {e}") - return Enum(self.asset_type.capitalize(), {}), {} + logging.error(f"Failed to fetch {asset_type_str} from API: {e}") + return None - if "items" not in assets_data: - return Enum(self.asset_type.capitalize(), {}), {} + def __init__(self, asset_type: AssetType | str, cache_filename: Optional[str] = None): + if isinstance(asset_type, str): + asset_type = AssetType(asset_type.lower()) + self.asset_type = asset_type - onboarded_assets = [asset for asset in assets_data["items"] if asset.get("status", "").lower() == "onboarded"] + filename = cache_filename if cache_filename else self.asset_type.value + self.cache_file = f"{CACHE_FOLDER}/{filename}.json" + self.lock_file = f"{self.cache_file}.lock" + os.makedirs(CACHE_FOLDER, exist_ok=True) - save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) + + def load_assets(self) -> Tuple[Enum, Dict]: + """ + Load assets from cache or fetch from backend if not cached. + """ + cached_data = self.load_from_cache(self.cache_file, self.lock_file) + if cached_data: + return self.parse_assets(cached_data) + + assets_data = self.fetch_assets_from_backend(self.asset_type.value) + if not assets_data or "items" not in assets_data: + return Enum(self.asset_type.name, {}), {} + + onboarded_assets = [ + asset for asset in assets_data["items"] + if asset.get("status", "").lower() == "onboarded" + ] + + self.save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) return self.parse_assets({"items": onboarded_assets}) def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: """ Convert asset data into an Enum and dictionary format for easy use. - - Args: - assets_data (Dict): JSON response with asset list. - - Returns: - - assets_enum: Enum with asset IDs. - - assets_details: Dictionary containing all asset parameters. """ - if not assets_data["items"]: # Handle case where no assets are onboarded - logging.warning(f"No onboarded {self.asset_type} found.") - return Enum(self.asset_type.capitalize(), {}), {} + if not assets_data["items"]: + logging.warning(f"No onboarded {self.asset_type.value} found.") + return Enum(self.asset_type.name, {}), {} assets_enum = Enum( - self.asset_type.capitalize(), + self.asset_type.name, {a["id"].upper().replace("-", "_"): a["id"] for a in assets_data["items"]}, type=str, ) assets_details = { - asset["id"]: { - "id": asset["id"], - "name": asset.get("name", ""), - "description": asset.get("description", ""), - "api_key": asset.get("api_key", config.TEAM_API_KEY), - "supplier": asset.get("supplier", "aiXplain"), - "version": asset.get("version", "1.0"), - "status": asset.get("status", "onboarded"), - "created_at": asset.get("created_at", ""), - **asset, # Include any extra fields - } + asset["id"]: Asset( + id=asset["id"], + name=asset.get("name", ""), + description=asset.get("description", ""), + api_key=asset.get("api_key", config.TEAM_API_KEY), + supplier=asset.get("supplier", "aiXplain"), + version=asset.get("version", "1.0"), + status=asset.get("status", "onboarded"), + created_at=asset.get("created_at", "") + ) for asset in assets_data["items"] } - return assets_enum, assets_details \ No newline at end of file + return assets_enum, assets_details diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a77f3cfc..62408c58 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -25,7 +25,7 @@ from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter import os @@ -40,7 +40,7 @@ def load_functions(): os.makedirs(CACHE_FOLDER, exist_ok=True) - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: url = urljoin(backend_url, "sdk/functions") @@ -51,7 +51,7 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index c129822f..e0d035bd 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -25,14 +25,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -46,7 +46,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index f9758b84..ef9285d9 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -26,14 +26,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) try: if resp is None: @@ -49,7 +49,7 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index e7e3def9..c07222e6 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,6 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.enums.aixplain_cache import AssetType class Model(Asset): @@ -91,36 +92,69 @@ def __init__( model_params (Dict, optional): parameters for the function. **additional_info: Any additional Model info to be saved """ - ModelCache = AixplainCache("models", "models") - ModelEnum, ModelDetails = ModelCache.load_assets() - - if id in ModelDetails: - cached_model = ModelDetails[id] - - input_params = cached_model.get("params", input_params) - function = cached_model.get("function", {}).get("name", function) - name = cached_model.get("name", name) - supplier = cached_model.get("supplier", supplier) - - created_at_str = cached_model.get("createdAt") + model_details = self._get_model_details_from_cache(id) + + if model_details: + name = model_details.get("name", name) + description = model_details.get("description", description) + supplier = model_details.get("supplier", supplier) + version = model_details.get("version", version) + function = model_details.get("function", {}).get("name", function) + cost = model_details.get("pricing", cost) + input_params = model_details.get("params", input_params) + output_params = model_details.get("output_params", output_params) + model_params = model_details.get("model_params", model_params) + created_at_str = model_details.get("createdAt") if created_at_str: created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) - - cost = cached_model.get("pricing", cost) - + is_subscribed = model_details.get("is_subscribed", is_subscribed) super().__init__(id, name, description, supplier, version, cost=cost) + self.api_key = api_key - self.additional_info = additional_info - self.name = name - self.url = config.MODELS_RUN_URL - self.backend_url = config.BACKEND_URL self.function = function self.is_subscribed = is_subscribed self.created_at = created_at self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + self.additional_info = additional_info + + if not model_details: + self._add_model_to_cache() + + @staticmethod + def _get_model_details_from_cache(model_id: str) -> Optional[Dict]: + """ + Private helper to load model details from the cache. + """ + try: + model_cache = AixplainCache(AssetType.MODELS) + _, models_data = model_cache.load_assets() + model_asset = models_data.get(model_id) + return model_asset.__dict__ if model_asset else None + except Exception as e: + logging.error(f"Error loading model from cache: {e}") + traceback.print_exc() + return None + + def _add_model_to_cache(self): + try: + + model_cache = AixplainCache(AssetType.MODELS) + _, models_data = model_cache.load_assets() + + models_data[self.id] = self + + serializable_data = {mid: vars(m) for mid, m in models_data.items()} + + model_cache.save_to_cache(model_cache.cache_file, {"items": list(serializable_data.values())}, model_cache.lock_file) + + logging.info(f"Model {self.id} added to cache.") + except Exception as e: + logging.error(f"Failed to add model {self.id} to cache: {e}") def to_dict(self) -> Dict: """Get the model info as a Dictionary From 42e3472b0bc3cdd94da9f8bee3157df6837ba90c Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 3 Apr 2025 12:34:50 +0300 Subject: [PATCH 07/13] Made changes to model cache --- aixplain/enums/__init__.py | 2 +- .../{aixplain_cache.py => asset_cache.py} | 21 +++++- aixplain/enums/function.py | 6 +- aixplain/enums/language.py | 6 +- aixplain/enums/license.py | 6 +- aixplain/factories/model_factory/__init__.py | 51 +++++++++++++- aixplain/modules/agent/__init__.py | 4 +- aixplain/modules/model/__init__.py | 67 ++----------------- tests/functional/model/run_model_test.py | 2 +- 9 files changed, 89 insertions(+), 76 deletions(-) rename aixplain/enums/{aixplain_cache.py => asset_cache.py} (86%) diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index f9cac81d..0eba518a 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -15,5 +15,5 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus -from .aixplain_cache import AixplainCache +from .asset_cache import AssetCache, Asset, AssetType from .database_source import DatabaseSourceType diff --git a/aixplain/enums/aixplain_cache.py b/aixplain/enums/asset_cache.py similarity index 86% rename from aixplain/enums/aixplain_cache.py rename to aixplain/enums/asset_cache.py index f8cb2f78..1887b6d5 100644 --- a/aixplain/enums/aixplain_cache.py +++ b/aixplain/enums/asset_cache.py @@ -41,7 +41,7 @@ def _serialize(obj): return obj.value return obj.__dict__ if hasattr(obj, "__dict__") else str(obj) -class AixplainCache: +class AssetCache: """ A modular caching system to handle different asset types (Models, Pipelines, Agents). """ @@ -98,6 +98,9 @@ def __init__(self, asset_type: AssetType | str, cache_filename: Optional[str] = self.lock_file = f"{self.cache_file}.lock" os.makedirs(CACHE_FOLDER, exist_ok=True) + # Load assets immediately during initialization + self.assets_enum, self.assets_data = self._initialize_assets() + def load_assets(self) -> Tuple[Enum, Dict]: """ @@ -149,3 +152,19 @@ def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: } return assets_enum, assets_details + + def _initialize_assets(self) -> Tuple[Enum, Dict]: + cached_data = self.load_from_cache(self.cache_file, self.lock_file) + if cached_data: + return self.parse_assets(cached_data) + + assets_data = self.fetch_assets_from_backend(self.asset_type.value) + if not assets_data or "items" not in assets_data: + return Enum(self.asset_type.name, {}), {} + + onboarded_assets = [ + asset for asset in assets_data["items"] + if asset.get("status", "").lower() == "onboarded" + ] + self.save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) + return self.parse_assets({"items": onboarded_assets}) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index 62408c58..12f57644 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -25,7 +25,7 @@ from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER +from .asset_cache import AssetCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter import os @@ -40,7 +40,7 @@ def load_functions(): os.makedirs(CACHE_FOLDER, exist_ok=True) - resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: url = urljoin(backend_url, "sdk/functions") @@ -51,7 +51,7 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index e0d035bd..f1f184de 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -25,14 +25,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER +from .asset_cache import AssetCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/languages.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_languages(): - resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL @@ -46,7 +46,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index ef9285d9..90a873f3 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -26,14 +26,14 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.enums.aixplain_cache import AixplainCache, CACHE_FOLDER +from .asset_cache import AssetCache, CACHE_FOLDER CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = AixplainCache.load_from_cache(CACHE_FILE, LOCK_FILE) + resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) try: if resp is None: @@ -49,7 +49,7 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - AixplainCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index b39cc668..733232ee 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -29,6 +29,8 @@ from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin +from aixplain.enums import AssetCache, AssetType, Asset +from aixplain.factories.model_factory.utils import create_model_from_response class ModelFactory: @@ -88,6 +90,17 @@ def create_utility_model( if 200 <= r.status_code < 300: utility_model.id = resp["id"] logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") + + new_asset = Asset( + id=resp["id"], + name=resp.get("name", ""), + description=resp.get("description", ""), + api_key=api_key or config.TEAM_API_KEY, + supplier=resp.get("supplier", "aiXplain"), + version=resp.get("version", "1.0"), + status=resp.get("status", "onboarded"), + created_at=resp.get("createdAt", "") + ) return utility_model else: error_message = ( @@ -97,7 +110,7 @@ def create_utility_model( raise Exception(error_message) @classmethod - def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: + def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True) -> Model: """Create a 'Model' object from model id Args: @@ -107,8 +120,27 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: Returns: Model: Created 'Model' object """ + cache = AssetCache(AssetType.MODELS) + asset = cache.assets_data.get(model_id) if use_cache else None + + if asset: + logging.info(f"ModelFactory: Loaded model {model_id} from cache.") + return Model( + id=asset.id, + name=asset.name, + description=asset.description, + api_key=api_key or asset.api_key, + supplier=asset.supplier, + version=asset.version, + created_at=asset.created_at, + ) + + url = urljoin(cls.backend_url, f"sdk/models/{model_id}") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + resp = None try: + logging.info(f"ModelFactory: Model {model_id} not in cache. Fetching from backend...") url = urljoin(cls.backend_url, f"sdk/models/{model_id}") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} @@ -131,9 +163,24 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: resp["api_key"] = api_key from aixplain.factories.model_factory.utils import create_model_from_response + new_asset = Asset( + id=resp["id"], + name=resp.get("name", ""), + description=resp.get("description", ""), + api_key=api_key or config.TEAM_API_KEY, + supplier=resp.get("supplier", "aiXplain"), + version=resp.get("version", "1.0"), + status=resp.get("status", "onboarded"), + created_at=resp.get("createdAt", "") + ) + cache.assets_data[new_asset.id] = new_asset + serializable_data = {k: vars(v) for k, v in cache.assets_data.items()} + cache.save_to_cache(cache.cache_file, {"items": list(serializable_data.values())}, cache.lock_file) + model = create_model_from_response(resp) - logging.info(f"Model Creation: Model {model_id} instantiated.") + logging.info(f"ModelFactory: Model {model_id} fetched and cached successfully.") return model + else: error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" logging.error(error_message) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 4bf54565..8ab5b38c 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -37,7 +37,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus, AixplainCache +from aixplain.enums import ResponseStatus, AssetCache from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -61,7 +61,7 @@ class Agent(Model): api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. """ - AgentCache = AixplainCache("agents", "agents") + AgentCache = AssetCache("agents", "agents") AgentEnum, AgentDetails = AgentCache.load_assets() is_valid: bool diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index c07222e6..3bac37af 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -23,7 +23,7 @@ import time import logging import traceback -from aixplain.enums import Supplier, Function, AixplainCache +from aixplain.enums import Supplier, Function from aixplain.modules.asset import Asset from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config @@ -34,7 +34,6 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters -from aixplain.enums.aixplain_cache import AssetType class Model(Asset): @@ -92,69 +91,17 @@ def __init__( model_params (Dict, optional): parameters for the function. **additional_info: Any additional Model info to be saved """ - model_details = self._get_model_details_from_cache(id) - - if model_details: - name = model_details.get("name", name) - description = model_details.get("description", description) - supplier = model_details.get("supplier", supplier) - version = model_details.get("version", version) - function = model_details.get("function", {}).get("name", function) - cost = model_details.get("pricing", cost) - input_params = model_details.get("params", input_params) - output_params = model_details.get("output_params", output_params) - model_params = model_details.get("model_params", model_params) - created_at_str = model_details.get("createdAt") - if created_at_str: - created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) - is_subscribed = model_details.get("is_subscribed", is_subscribed) - super().__init__(id, name, description, supplier, version, cost=cost) - self.api_key = api_key + self.additional_info = additional_info + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL self.function = function self.is_subscribed = is_subscribed self.created_at = created_at self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None - self.url = config.MODELS_RUN_URL - self.backend_url = config.BACKEND_URL - self.additional_info = additional_info - - if not model_details: - self._add_model_to_cache() - - @staticmethod - def _get_model_details_from_cache(model_id: str) -> Optional[Dict]: - """ - Private helper to load model details from the cache. - """ - try: - model_cache = AixplainCache(AssetType.MODELS) - _, models_data = model_cache.load_assets() - model_asset = models_data.get(model_id) - return model_asset.__dict__ if model_asset else None - except Exception as e: - logging.error(f"Error loading model from cache: {e}") - traceback.print_exc() - return None - - def _add_model_to_cache(self): - try: - - model_cache = AixplainCache(AssetType.MODELS) - _, models_data = model_cache.load_assets() - - models_data[self.id] = self - - serializable_data = {mid: vars(m) for mid, m in models_data.items()} - - model_cache.save_to_cache(model_cache.cache_file, {"items": list(serializable_data.values())}, model_cache.lock_file) - - logging.info(f"Model {self.id} added to cache.") - except Exception as e: - logging.error(f"Failed to add model {self.id} to cache: {e}") def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -296,7 +243,7 @@ def run( start = time.time() payload = build_payload(data=data, parameters=parameters) url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") - logging.debug(f"Model Run Sync: Start service for {name} - {url}") + logging.debug(f"Model Run Sync: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) if response["status"] == "IN_PROGRESS": try: @@ -334,8 +281,8 @@ def run_async( dict: polling URL in response """ url = f"{self.url}/{self.id}" - logging.debug(f"Model Run Async: Start service for {name} - {url}") payload = build_payload(data=data, parameters=parameters) + logging.debug(f"Model Run Async: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( status=response.pop("status", ResponseStatus.FAILED), @@ -426,4 +373,4 @@ def delete(self) -> None: except Exception: message = "Model Deletion Error: Make sure the model exists and you are the owner." logging.error(message) - raise Exception(f"{message}") + raise Exception(f"{message}") \ No newline at end of file diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 6c601917..b78855a1 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -99,7 +99,7 @@ def test_llm_run_with_file(): def test_aixplain_model_cache_creation(): - """Ensure AixplainCache is triggered and cache is created.""" + """Ensure AssetCache is triggered and cache is created.""" cache_file = os.path.join(CACHE_FOLDER, "models.json") From ea612329764b200cc2036f891b951cc0c80d61ec Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 3 Apr 2025 12:42:14 +0300 Subject: [PATCH 08/13] Made changes to model cache --- aixplain/factories/model_factory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 733232ee..b6c040a9 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -140,7 +140,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = T resp = None try: - logging.info(f"ModelFactory: Model {model_id} not in cache. Fetching from backend...") + logging.info(f"Fetching Model from backend...") url = urljoin(cls.backend_url, f"sdk/models/{model_id}") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} From 380e3ca2d8dc599d971ff9fcdf312ea3965c0583 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 10 Apr 2025 16:34:27 +0300 Subject: [PATCH 09/13] model changes --- aixplain/enums/__init__.py | 2 +- aixplain/enums/asset_cache.py | 282 ++++++++++--------- aixplain/enums/function.py | 23 +- aixplain/enums/language.py | 26 +- aixplain/enums/license.py | 35 ++- aixplain/factories/model_factory/__init__.py | 107 +++---- aixplain/modules/agent/__init__.py | 4 +- aixplain/modules/pipeline/asset.py | 4 +- 8 files changed, 226 insertions(+), 257 deletions(-) diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 0eba518a..51c663de 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -15,5 +15,5 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus -from .asset_cache import AssetCache, Asset, AssetType +from .asset_cache import AssetCache from .database_source import DatabaseSourceType diff --git a/aixplain/enums/asset_cache.py b/aixplain/enums/asset_cache.py index 1887b6d5..836ede1a 100644 --- a/aixplain/enums/asset_cache.py +++ b/aixplain/enums/asset_cache.py @@ -2,21 +2,27 @@ import logging import json import time -from enum import Enum -from urllib.parse import urljoin -from typing import Dict, Optional, Tuple -from dataclasses import dataclass +from typing import Dict, Optional +from dataclasses import dataclass, asdict from filelock import FileLock from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry -from aixplain.enums.privacy import Privacy +from aixplain.utils.file_utils import _request_with_retry +from urllib.parse import urljoin +from typing import TypeVar, Generic, Type +from typing import List + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") + # Constants CACHE_FOLDER = ".cache" DEFAULT_CACHE_EXPIRY = 86400 @dataclass -class Asset: +class Model: id: str name: str = "" description: str = "" @@ -26,145 +32,143 @@ class Asset: status: str = "onboarded" created_at: str = "" -class AssetType(Enum): - MODELS = "models" - PIPELINES = "pipelines" - AGENTS = "agents" + @classmethod + def from_dict(cls, data: Dict) -> "Model": + return cls(**data) -def get_cache_expiry(): - return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) - -def _serialize(obj): - if isinstance(obj, (Privacy)): - return str(obj) # or obj.to_dict() if you have it - if isinstance(obj, Enum): - return obj.value - return obj.__dict__ if hasattr(obj, "__dict__") else str(obj) +@dataclass +class Store(Generic[T]): + data: Dict[str, T] + expiry: int -class AssetCache: +class AssetCache(Generic[T]): """ A modular caching system to handle different asset types (Models, Pipelines, Agents). """ - @staticmethod - def save_to_cache(cache_file, data, lock_file): - os.makedirs(os.path.dirname(cache_file), exist_ok=True) - with FileLock(lock_file): - with open(cache_file, "w") as f: - json.dump({"timestamp": time.time(), "data": data}, f,default=_serialize) - - @staticmethod - def load_from_cache(cache_file, lock_file): - if os.path.exists(cache_file): - with FileLock(lock_file): - with open(cache_file, "r") as f: - cache_data = json.load(f) - if time.time() - cache_data["timestamp"] < get_cache_expiry(): - return cache_data["data"] - else: - try: - os.remove(cache_file) - if os.path.exists(lock_file): - os.remove(lock_file) - except Exception as e: - logging.warning(f"Failed to remove expired cache or lock file: {e}") - return None - - @staticmethod - def fetch_assets_from_backend(asset_type_str: str) -> Optional[Dict]: - """ - Fetch assets data from the backend API. - """ - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - url = urljoin(backend_url, f"sdk/{asset_type_str}") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - + def __init__( + self, + cls: Type[T], + cache_filename: Optional[str] = None, + ): + self.cls = cls + if cache_filename is None: + cache_filename = self.cls.__name__.lower() + + # create cache file and lock file name + self.cache_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.json") + self.lock_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.lock") + self.store = Store(data={}, expiry=self.compute_expiry()) + self.load() + + if not os.path.exists(self.cache_file): + self.save() + + def compute_expiry(self): try: - response = _request_with_retry("get", url, headers=headers) - response.raise_for_status() - return response.json() + expiry = int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) except Exception as e: - logging.error(f"Failed to fetch {asset_type_str} from API: {e}") - return None - - def __init__(self, asset_type: AssetType | str, cache_filename: Optional[str] = None): - if isinstance(asset_type, str): - asset_type = AssetType(asset_type.lower()) - self.asset_type = asset_type - - filename = cache_filename if cache_filename else self.asset_type.value - self.cache_file = f"{CACHE_FOLDER}/{filename}.json" - self.lock_file = f"{self.cache_file}.lock" + logger.warning( + f"Failed to parse CACHE_EXPIRY_TIME: {e}, " + f"fallback to default value {DEFAULT_CACHE_EXPIRY}" + ) + # remove the CACHE_EXPIRY_TIME from the environment variables + del os.environ["CACHE_EXPIRY_TIME"] + expiry = DEFAULT_CACHE_EXPIRY + + return time.time() + int(expiry) + + def invalidate(self): + self.store = Store(data={}, expiry=self.compute_expiry()) + # delete cache file and lock file + if os.path.exists(self.cache_file): + os.remove(self.cache_file) + if os.path.exists(self.lock_file): + os.remove(self.lock_file) + + def load(self): + if not os.path.exists(self.cache_file): + self.invalidate() + return + + with FileLock(self.lock_file): + with open(self.cache_file, "r") as f: + try: + cache_data = json.load(f) + except Exception as e: + # data is corrupted, invalidate the cache + self.invalidate() + logging.warning(f"Failed to parse cache file: {e}") + return + + try: + expiry = cache_data["expiry"] + raw_data = cache_data["data"] + parsed_data = { + k: self.cls( + id=v.get("id", ""), + name=v.get("name", ""), + description=v.get("description", ""), + api_key=v.get("api_key", config.TEAM_API_KEY), + supplier=v.get("supplier", "aiXplain"), + version=v.get("version", "1.0"), + status=v.get("status", "onboarded"), + created_at=v.get("created_at", ""), + ) for k, v in raw_data.items() + } + + + self.store = Store(data=parsed_data, expiry=expiry) + except Exception as e: + self.invalidate() + logging.warning(f"Failed to load cache data: {e}") + + + if self.store.expiry < time.time(): + logger.warning( + f"Cache expired, invalidating cache for {self.cls.__name__}" + ) + # cache expired, invalidate the cache + self.invalidate() + return + + def save(self): os.makedirs(CACHE_FOLDER, exist_ok=True) - # Load assets immediately during initialization - self.assets_enum, self.assets_data = self._initialize_assets() - - - def load_assets(self) -> Tuple[Enum, Dict]: - """ - Load assets from cache or fetch from backend if not cached. - """ - cached_data = self.load_from_cache(self.cache_file, self.lock_file) - if cached_data: - return self.parse_assets(cached_data) - - assets_data = self.fetch_assets_from_backend(self.asset_type.value) - if not assets_data or "items" not in assets_data: - return Enum(self.asset_type.name, {}), {} - - onboarded_assets = [ - asset for asset in assets_data["items"] - if asset.get("status", "").lower() == "onboarded" - ] - - self.save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) - - return self.parse_assets({"items": onboarded_assets}) - - def parse_assets(self, assets_data: Dict) -> Tuple[Enum, Dict]: - """ - Convert asset data into an Enum and dictionary format for easy use. - """ - if not assets_data["items"]: - logging.warning(f"No onboarded {self.asset_type.value} found.") - return Enum(self.asset_type.name, {}), {} - - assets_enum = Enum( - self.asset_type.name, - {a["id"].upper().replace("-", "_"): a["id"] for a in assets_data["items"]}, - type=str, - ) - - assets_details = { - asset["id"]: Asset( - id=asset["id"], - name=asset.get("name", ""), - description=asset.get("description", ""), - api_key=asset.get("api_key", config.TEAM_API_KEY), - supplier=asset.get("supplier", "aiXplain"), - version=asset.get("version", "1.0"), - status=asset.get("status", "onboarded"), - created_at=asset.get("created_at", "") - ) - for asset in assets_data["items"] - } - - return assets_enum, assets_details - - def _initialize_assets(self) -> Tuple[Enum, Dict]: - cached_data = self.load_from_cache(self.cache_file, self.lock_file) - if cached_data: - return self.parse_assets(cached_data) - - assets_data = self.fetch_assets_from_backend(self.asset_type.value) - if not assets_data or "items" not in assets_data: - return Enum(self.asset_type.name, {}), {} - - onboarded_assets = [ - asset for asset in assets_data["items"] - if asset.get("status", "").lower() == "onboarded" - ] - self.save_to_cache(self.cache_file, {"items": onboarded_assets}, self.lock_file) - return self.parse_assets({"items": onboarded_assets}) + with FileLock(self.lock_file): + with open(self.cache_file, "w") as f: + # serialize the data manually + serializable_store = { + "expiry": self.compute_expiry(), + "data": { + asset_id: { + "id": model.id, + "name": model.name, + "description": model.description, + "api_key": model.api_key, + "supplier": model.supplier, + "version": model.version, + "created_at": model.created_at.isoformat() if hasattr(model.created_at, "isoformat") else model.created_at, + } + for asset_id, model in self.store.data.items() + }, + } + json.dump(serializable_store, f) + + + def get(self, asset_id: str) -> Optional[T]: + return self.store.data.get(asset_id) + + def add(self, asset: T): + self.store.data[asset.id] = asset + self.save() + + def add_model_list(self, models: List[T]): + self.store.data = {model.id: model for model in models} + self.save() + + def get_all_models(self) -> List[T]: + return list(self.store.data.values()) + + def has_valid_cache(self) -> bool: + return self.store.expiry >= time.time() \ No newline at end of file diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index 12f57644..b56ee5cb 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -40,18 +40,17 @@ def load_functions(): os.makedirs(CACHE_FOLDER, exist_ok=True) - resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) - if resp is None: - url = urljoin(backend_url, "sdk/functions") - - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + # resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) + url = urljoin(backend_url, "sdk/functions") + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + ) + resp = r.json() + # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) class Function(str, Enum): def __new__(cls, value): diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index f1f184de..4da52826 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -32,21 +32,21 @@ def load_languages(): - resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) - if resp is None: - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL + # resp = AssetCache._load_from_cache(CACHE_FILE, LOCK_FILE) + # if resp is None: + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL - url = urljoin(backend_url, "sdk/languages") + url = urljoin(backend_url, "sdk/languages") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + ) + resp = r.json() + # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) languages = {} for w in resp: diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 90a873f3..e87393f5 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -26,30 +26,29 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from .asset_cache import AssetCache, CACHE_FOLDER +# from aixplain.enums import AssetCache -CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" -LOCK_FILE = f"{CACHE_FILE}.lock" +# CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" +# LOCK_FILE = f"{CACHE_FILE}.lock" def load_licenses(): - resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) + # resp = AssetCache._load_from_cache(CACHE_FILE, LOCK_FILE) try: - if resp is None: - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL - - url = urljoin(backend_url, "sdk/licenses") - - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL + + url = urljoin(backend_url, "sdk/licenses") + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + ) + resp = r.json() + # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) licenses = {"_".join(w["name"].split()): w["id"] for w in resp} return Enum("License", licenses, type=str) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index b6c040a9..566e3f1d 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -29,7 +29,7 @@ from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from aixplain.enums import AssetCache, AssetType, Asset +from aixplain.enums import AssetCache from aixplain.factories.model_factory.utils import create_model_from_response @@ -90,17 +90,6 @@ def create_utility_model( if 200 <= r.status_code < 300: utility_model.id = resp["id"] logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") - - new_asset = Asset( - id=resp["id"], - name=resp.get("name", ""), - description=resp.get("description", ""), - api_key=api_key or config.TEAM_API_KEY, - supplier=resp.get("supplier", "aiXplain"), - version=resp.get("version", "1.0"), - status=resp.get("status", "onboarded"), - created_at=resp.get("createdAt", "") - ) return utility_model else: error_message = ( @@ -111,81 +100,59 @@ def create_utility_model( @classmethod def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True) -> Model: - """Create a 'Model' object from model id - - Args: - model_id (Text): Model ID of required model. - api_key (Optional[Text], optional): Model API key. Defaults to None. - - Returns: - Model: Created 'Model' object - """ - cache = AssetCache(AssetType.MODELS) - asset = cache.assets_data.get(model_id) if use_cache else None - - if asset: - logging.info(f"ModelFactory: Loaded model {model_id} from cache.") - return Model( - id=asset.id, - name=asset.name, - description=asset.description, - api_key=api_key or asset.api_key, - supplier=asset.supplier, - version=asset.version, - created_at=asset.created_at, - ) - - url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + """Create a 'Model' object from model id""" + cache = AssetCache(Model) + + if use_cache: + if cache.has_valid_cache(): + cached_model = cache.get(model_id) + if cached_model: + return cached_model + logging.info("Model not found in valid cache, fetching individually...") + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model + else: + try: + model_list_resp = cls.list(model_ids=None, api_key=api_key) + models = model_list_resp["results"] + cache.add_model_list(models) + for model in models: + if model.id == model_id: + return model + except Exception as e: + logging.error(f"Error fetching model list: {e}") + raise e + + logging.info("Fetching model directly without cache...") + return cls._fetch_model_by_id(model_id, api_key) + @classmethod + def _fetch_model_by_id(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: resp = None try: - logging.info(f"Fetching Model from backend...") url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start service for GET Model - {url} - {headers}") + headers = {"Authorization": f"Token {api_key or config.TEAM_API_KEY}", "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) resp = r.json() - except Exception: - if resp is not None and "statusCode" in resp: + if resp and "statusCode" in resp: status_code = resp["statusCode"] - message = resp["message"] - message = f"Model Creation: Status {status_code} - {message}" + message = f"Model Creation: Status {status_code} - {resp['message']}" else: message = "Model Creation: Unspecified Error" logging.error(message) - raise Exception(f"{message}") - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - from aixplain.factories.model_factory.utils import create_model_from_response - - new_asset = Asset( - id=resp["id"], - name=resp.get("name", ""), - description=resp.get("description", ""), - api_key=api_key or config.TEAM_API_KEY, - supplier=resp.get("supplier", "aiXplain"), - version=resp.get("version", "1.0"), - status=resp.get("status", "onboarded"), - created_at=resp.get("createdAt", "") - ) - cache.assets_data[new_asset.id] = new_asset - serializable_data = {k: vars(v) for k, v in cache.assets_data.items()} - cache.save_to_cache(cache.cache_file, {"items": list(serializable_data.values())}, cache.lock_file) - - model = create_model_from_response(resp) - logging.info(f"ModelFactory: Model {model_id} fetched and cached successfully.") - return model + raise Exception(message) + if 200 <= r.status_code < 300: + resp["api_key"] = api_key or config.TEAM_API_KEY + return create_model_from_response(resp) else: error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" logging.error(error_message) raise Exception(error_message) + @classmethod def list( cls, diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8ab5b38c..5c3fca5b 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -61,8 +61,8 @@ class Agent(Model): api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. """ - AgentCache = AssetCache("agents", "agents") - AgentEnum, AgentDetails = AgentCache.load_assets() + # AgentCache = AssetCache("agents", "agents") + # AgentEnum, AgentDetails = AgentCache.load_assets() is_valid: bool diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 14b24704..206ae3b3 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -32,7 +32,7 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin -from aixplain.enums import AixplainCache +from aixplain.enums import AssetCache from aixplain.modules.pipeline.response import PipelineResponse @@ -75,7 +75,7 @@ def __init__( status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ - PipelineCache = AixplainCache("pipelines", "pipelines") + PipelineCache = AssetCache("pipelines", "pipelines") _, PipelineDetails = PipelineCache.load_assets() if id in PipelineDetails: cached_pipeline = PipelineDetails[id] From 134033c2b077f39ecb3d35360686b1a5381e20e3 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 21 Apr 2025 11:20:05 +0300 Subject: [PATCH 10/13] changes for functions, languages, licenses, agents, pipelines caching structure --- aixplain/enums/__init__.py | 1 - aixplain/enums/function.py | 79 +++++--- aixplain/enums/language.py | 56 ++++-- aixplain/enums/license.py | 41 ++-- aixplain/factories/agent_factory/__init__.py | 185 +++++++++++++----- aixplain/factories/model_factory/__init__.py | 138 +++++++++---- .../factories/pipeline_factory/__init__.py | 161 ++++++++++----- aixplain/modules/agent/__init__.py | 110 +++++++---- aixplain/modules/model/__init__.py | 98 ++++++++-- aixplain/modules/pipeline/asset.py | 152 ++++++++++---- aixplain/modules/pipeline/default.py | 17 +- aixplain/{enums => utils}/asset_cache.py | 80 +++----- 12 files changed, 785 insertions(+), 333 deletions(-) rename aixplain/{enums => utils}/asset_cache.py (57%) diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 51c663de..555f4920 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -15,5 +15,4 @@ from .sort_by import SortBy from .sort_order import SortOrder from .response_status import ResponseStatus -from .asset_cache import AssetCache from .database_source import DatabaseSourceType diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index b56ee5cb..d606b98f 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -20,12 +20,12 @@ Description: Function Enum """ - +import logging from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from .asset_cache import AssetCache, CACHE_FOLDER +from ..utils.asset_cache import AssetCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter import os @@ -34,23 +34,45 @@ LOCK_FILE = f"{CACHE_FILE}.lock" +class FunctionMetadata: + def __init__(self, data: dict): + self.__dict__.update(data) + + def __repr__(self): + return f"" + + def to_dict(self) -> dict: + return self.__dict__ + + @classmethod + def from_dict(cls, data: dict): + return cls(data) + + def load_functions(): api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL os.makedirs(CACHE_FOLDER, exist_ok=True) - # resp = AssetCache.load_from_cache(CACHE_FILE, LOCK_FILE) url = urljoin(backend_url, "sdk/functions") - - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + cache = AssetCache(FunctionMetadata, cache_filename="functions") + if cache.has_valid_cache(): + logging.info("Loading functions from cache...") + function_objects = list(cache.store.data.values()) + else: + logging.info("Fetching functions from backend...") + url = urljoin(backend_url, "sdk/functions") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + ) + resp = r.json() + results = resp.get("results") + function_objects = [FunctionMetadata(f) for f in results] + cache.add_list(function_objects) class Function(str, Enum): def __new__(cls, value): @@ -68,8 +90,12 @@ def get_input_output_params(self) -> Tuple[Dict, Dict]: function_io = FunctionInputOutput.get(self.value, None) if function_io is None: return {}, {} - input_params = {param["code"]: param for param in function_io["spec"]["params"]} - output_params = {param["code"]: param for param in function_io["spec"]["output"]} + input_params = { + param["code"]: param for param in function_io["spec"]["params"] + } + output_params = { + param["code"]: param for param in function_io["spec"]["output"] + } return input_params, output_params def get_parameters(self) -> "FunctionParameters": @@ -83,19 +109,18 @@ def get_parameters(self) -> "FunctionParameters": self._parameters = FunctionParameters(input_params) return self._parameters - functions = Function("Function", {w["id"].upper().replace("-", "_"): w["id"] for w in resp["items"]}) + functions = Function( + "Function", {f.id.upper().replace("-", "_"): f.id for f in function_objects} + ) functions_input_output = { - function["id"]: { - "input": { - input_data_object["dataType"] - for input_data_object in function["params"] - if input_data_object["required"] is True - }, - "output": {output_data_object["dataType"] for output_data_object in function["output"]}, - "spec": function, + f.id: { + "input": {p["dataType"] for p in f.params if p.get("required")}, + "output": {o["dataType"] for o in f.output}, + "spec": f.to_dict(), } - for function in resp["items"] + for f in function_objects } + return functions, functions_input_output @@ -110,7 +135,11 @@ def __init__(self, input_params: Dict): """ super().__init__() for param_code, param_config in input_params.items(): - self.parameters[param_code] = Parameter(name=param_code, required=param_config.get("required", False), value=None) + self.parameters[param_code] = Parameter( + name=param_code, + required=param_config.get("required", False), + value=None, + ) Function, FunctionInputOutput = load_functions() diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index 4da52826..f8a61381 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -25,39 +25,57 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from .asset_cache import AssetCache, CACHE_FOLDER +from aixplain.utils.asset_cache import AssetCache +import logging -CACHE_FILE = f"{CACHE_FOLDER}/languages.json" -LOCK_FILE = f"{CACHE_FILE}.lock" + +class LanguageMetadata: + def __init__(self, data: dict): + self.__dict__.update(data) + + def to_dict(self): + return self.__dict__ + + @classmethod + def from_dict(cls, data: dict): + return cls(data) def load_languages(): - # resp = AssetCache._load_from_cache(CACHE_FILE, LOCK_FILE) - # if resp is None: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/languages") + cache = AssetCache(LanguageMetadata, cache_filename="languages") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + if cache.has_valid_cache(): + logging.info("Loading languages from cache...") + lang_entries = list(cache.store.data.values()) + else: + logging.info("Fetching languages from backend...") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + ) + resp = r.json() + lang_entries = [LanguageMetadata(item) for item in resp] + cache.add_list(lang_entries) languages = {} - for w in resp: - language = w["value"] - language_label = "_".join(w["label"].split()) - languages[language_label] = {"language": language, "dialect": ""} - for dialect in w["dialects"]: + for entry in lang_entries: + language = entry.value + label = "_".join(entry.label.split()) + languages[label] = {"language": language, "dialect": ""} + for dialect in entry.dialects: dialect_label = "_".join(dialect["label"].split()).upper() dialect_value = dialect["value"] + languages[f"{label}_{dialect_label}"] = { + "language": language, + "dialect": dialect_value, + } - languages[language_label + "_" + dialect_label] = {"language": language, "dialect": dialect_value} return Enum("Language", languages, type=dict) diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index e87393f5..5489173a 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -26,31 +26,46 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -# from aixplain.enums import AssetCache +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER -# CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" -# LOCK_FILE = f"{CACHE_FILE}.lock" + +class LicenseMetadata: + def __init__(self, data: dict): + self.__dict__.update(data) + + def to_dict(self): + return self.__dict__ + + @classmethod + def from_dict(cls, data: dict): + return cls(data) def load_licenses(): - # resp = AssetCache._load_from_cache(CACHE_FILE, LOCK_FILE) try: api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/licenses") + cache = AssetCache(LicenseMetadata, cache_filename="licenses") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - r = _request_with_retry("get", url, headers=headers) - if not 200 <= r.status_code < 300: - raise Exception( - f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' - ) - resp = r.json() - # AssetCache.save_to_cache(CACHE_FILE, resp, LOCK_FILE) + if cache.has_valid_cache(): + logging.info("Loading licenses from cache...") + license_objects = list(cache.store.data.values()) + else: + logging.info("Fetching licenses from backend...") + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", url, headers=headers) + if not 200 <= r.status_code < 300: + raise Exception( + f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation.' + ) + resp = r.json() + license_objects = [LicenseMetadata(item) for item in resp] + cache.add_list(license_objects) - licenses = {"_".join(w["name"].split()): w["id"] for w in resp} + licenses = {"_".join(lic.name.split()): lic.id for lic in license_objects} return Enum("License", licenses, type=str) except Exception: logging.exception("License Loading Error") diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index c537cb9b..ceb3c3d5 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -39,6 +39,7 @@ from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config +from aixplain.utils.asset_cache import AssetCache from typing import Callable, Dict, List, Optional, Text, Union from aixplain.utils.file_utils import _request_with_retry @@ -100,21 +101,34 @@ def create( payload = { "name": name, "assets": [ - tool.to_dict() - if isinstance(tool, Tool) - else { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "supplier": tool.supplier.value["code"] if isinstance(tool.supplier, Supplier) else tool.supplier, - "parameters": tool.get_parameters().to_list() - if hasattr(tool, "get_parameters") and tool.get_parameters() is not None - else None, - "function": tool.function if hasattr(tool, "function") and tool.function is not None else None, - "type": "model", - "version": tool.version if hasattr(tool, "version") else None, - "assetId": tool.id, - } + ( + tool.to_dict() + if isinstance(tool, Tool) + else { + "id": tool.id, + "name": tool.name, + "description": tool.description, + "supplier": ( + tool.supplier.value["code"] + if isinstance(tool.supplier, Supplier) + else tool.supplier + ), + "parameters": ( + tool.get_parameters().to_list() + if hasattr(tool, "get_parameters") + and tool.get_parameters() is not None + else None + ), + "function": ( + tool.function + if hasattr(tool, "function") and tool.function is not None + else None + ), + "type": "model", + "version": tool.version if hasattr(tool, "version") else None, + "assetId": tool.id, + } + ) for tool in tools ], "description": description, @@ -129,11 +143,15 @@ def create( agent.validate(raise_exception=True) response = "Unspecified error" try: - logging.debug(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}") + logging.debug( + f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}" + ) r = _request_with_retry("post", url, headers=headers, json=agent.to_dict()) response = r.json() except Exception: - raise Exception("Agent Onboarding Error: Please contact the administrators.") + raise Exception( + "Agent Onboarding Error: Please contact the administrators." + ) if 200 <= r.status_code < 300: agent = build_agent(payload=response, tools=tools, api_key=api_key) @@ -152,9 +170,18 @@ def create( @classmethod def create_task( - cls, name: Text, description: Text, expected_output: Text, dependencies: Optional[List[Text]] = None + cls, + name: Text, + description: Text, + expected_output: Text, + dependencies: Optional[List[Text]] = None, ) -> AgentTask: - return AgentTask(name=name, description=description, expected_output=expected_output, dependencies=dependencies) + return AgentTask( + name=name, + description=description, + expected_output=expected_output, + dependencies=dependencies, + ) @classmethod def create_model_tool( @@ -172,14 +199,27 @@ def create_model_tool( if supplier is not None: if isinstance(supplier, str): for supplier_ in Supplier: - if supplier.lower() in [supplier_.value["code"].lower(), supplier_.value["name"].lower()]: + if supplier.lower() in [ + supplier_.value["code"].lower(), + supplier_.value["name"].lower(), + ]: supplier = supplier_ break - assert isinstance(supplier, Supplier), f"Supplier {supplier} is not a valid supplier" - return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) + assert isinstance( + supplier, Supplier + ), f"Supplier {supplier} is not a valid supplier" + return ModelTool( + function=function, + supplier=supplier, + model=model, + description=description, + parameters=parameters, + ) @classmethod - def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: + def create_pipeline_tool( + cls, description: Text, pipeline: Union[Pipeline, Text] + ) -> PipelineTool: """Create a new pipeline tool.""" return PipelineTool(description=description, pipeline=pipeline) @@ -189,7 +229,9 @@ def create_python_interpreter_tool(cls) -> PythonInterpreterTool: return PythonInterpreterTool() @classmethod - def create_custom_python_code_tool(cls, code: Union[Text, Callable], description: Text = "") -> CustomPythonCodeTool: + def create_custom_python_code_tool( + cls, code: Union[Text, Callable], description: Text = "" + ) -> CustomPythonCodeTool: """Create a new custom python code tool.""" return CustomPythonCodeTool(description=description, code=code) @@ -254,7 +296,9 @@ def create_sql_tool( # Already the correct type, no conversion needed pass else: - raise SQLToolError(f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}") + raise SQLToolError( + f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}" + ) database_path = None # Final database path to pass to SQLTool @@ -283,7 +327,9 @@ def create_sql_tool( try: os.remove(db_path) except Exception as cleanup_error: - warnings.warn(f"Failed to remove temporary database file '{db_path}': {str(cleanup_error)}") + warnings.warn( + f"Failed to remove temporary database file '{db_path}': {str(cleanup_error)}" + ) raise SQLToolError(f"Failed to create database from CSV: {str(e)}") # Handle SQLite source type @@ -291,7 +337,9 @@ def create_sql_tool( if not os.path.exists(source): raise SQLToolError(f"Database '{source}' does not exist") if not source.endswith(".db") and not source.endswith(".sqlite"): - raise SQLToolError(f"Database '{source}' must have .db or .sqlite extension") + raise SQLToolError( + f"Database '{source}' must have .db or .sqlite extension" + ) database_path = source @@ -331,7 +379,9 @@ def list(cls) -> Dict: resp = {} payload = {} - logging.info(f"Start service for GET List Agents - {url} - {headers} - {json.dumps(payload)}") + logging.info( + f"Start service for GET List Agents - {url} - {headers} - {json.dumps(payload)}" + ) try: r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -343,10 +393,17 @@ def list(cls) -> Dict: results = resp page_total = len(results) total = len(results) - logging.info(f"Response for GET List Agents - Page Total: {page_total} / Total: {total}") + logging.info( + f"Response for GET List Agents - Page Total: {page_total} / Total: {total}" + ) for agent in results: agents.append(build_agent(agent)) - return {"results": agents, "page_total": page_total, "page_number": 0, "total": total} + return { + "results": agents, + "page_total": page_total, + "page_number": 0, + "total": total, + } else: error_msg = "Agent Listing Error: Please contact the administrators." if isinstance(resp, dict) and "message" in resp: @@ -356,22 +413,60 @@ def list(cls) -> Dict: raise Exception(error_msg) @classmethod - def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: - """Get agent by id.""" + def get( + cls, agent_id: Text, api_key: Optional[Text] = None, use_cache: bool = True + ) -> Agent: from aixplain.factories.agent_factory.utils import build_agent + from aixplain.utils.asset_cache import AssetCache + + cache = AssetCache(Agent) + api_key = api_key or config.TEAM_API_KEY + + if use_cache: + if cache.has_valid_cache(): + cached_agent = cache.store.data.get(agent_id) + if cached_agent: + logging.info(f"Agent {agent_id} retrieved from valid cache.") + return cached_agent + else: + logging.info( + "No valid cache found — fetching full agent list to build cache." + ) + try: + agent_list_resp = cls.list() + agents = agent_list_resp.get("results", []) + cache.add_list(agents) + logging.info(f"Cache rebuilt with {len(agents)} agents.") + + for agent in agents: + if agent.id == agent_id: + logging.info( + f"Agent {agent_id} retrieved from newly built cache." + ) + return agent + except Exception as e: + logging.error(f"Error rebuilding agent cache: {e}") + raise e + # Fallback: direct fetch if cache not used or agent not found + logging.info(f"Fetching agent {agent_id} directly from backend.") url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent_id}") - - api_key = api_key if api_key is not None else config.TEAM_API_KEY headers = {"x-api-key": api_key, "Content-Type": "application/json"} - logging.info(f"Start service for GET Agent - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if 200 <= r.status_code < 300: - return build_agent(resp) - else: - msg = "Please contact the administrators." - if "message" in resp: - msg = resp["message"] - error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" - raise Exception(error_msg) + + try: + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + + if 200 <= r.status_code < 300: + agent = build_agent(resp) + cache.add(agent) # still helpful for future use + logging.info( + f"Agent {agent_id} fetched from backend and added to cache." + ) + return agent + else: + msg = resp.get("message", "Please contact the administrators.") + raise Exception(f"Agent Get Error (HTTP {r.status_code}): {msg}") + except Exception as e: + logging.exception(f"Agent Get Error: {e}") + raise diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 566e3f1d..6300ab0f 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -25,11 +25,18 @@ import logging from aixplain.modules.model import Model from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder +from aixplain.enums import ( + Function, + Language, + OwnershipType, + Supplier, + SortBy, + SortOrder, +) from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from aixplain.enums import AssetCache +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER from aixplain.factories.model_factory.utils import create_model_from_response @@ -80,7 +87,9 @@ def create_utility_model( url = urljoin(cls.backend_url, "sdk/utilities") headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} try: - logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}") + logging.info( + f"Start service for POST Utility Model - {url} - {headers} - {payload}" + ) r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() except Exception as e: @@ -89,23 +98,25 @@ def create_utility_model( if 200 <= r.status_code < 300: utility_model.id = resp["id"] - logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") + logging.info( + f"Utility Model Creation: Model {utility_model.id} instantiated." + ) return utility_model else: - error_message = ( - f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" - ) + error_message = f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" logging.error(error_message) raise Exception(error_message) @classmethod - def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True) -> Model: + def get( + cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True + ) -> Model: """Create a 'Model' object from model id""" cache = AssetCache(Model) if use_cache: if cache.has_valid_cache(): - cached_model = cache.get(model_id) + cached_model = cache.store.data.get(model_id) if cached_model: return cached_model logging.info("Model not found in valid cache, fetching individually...") @@ -116,7 +127,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = T try: model_list_resp = cls.list(model_ids=None, api_key=api_key) models = model_list_resp["results"] - cache.add_model_list(models) + cache.add_list(models) for model in models: if model.id == model_id: return model @@ -125,14 +136,22 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = T raise e logging.info("Fetching model directly without cache...") - return cls._fetch_model_by_id(model_id, api_key) + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model @classmethod - def _fetch_model_by_id(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: + def _fetch_model_by_id( + cls, model_id: Text, api_key: Optional[Text] = None + ) -> Model: resp = None try: url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - headers = {"Authorization": f"Token {api_key or config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key or config.TEAM_API_KEY}", + "Content-Type": "application/json", + } + logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() except Exception: @@ -145,14 +164,21 @@ def _fetch_model_by_id(cls, model_id: Text, api_key: Optional[Text] = None) -> M raise Exception(message) if 200 <= r.status_code < 300: - resp["api_key"] = api_key or config.TEAM_API_KEY - return create_model_from_response(resp) + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + from aixplain.factories.model_factory.utils import ( + create_model_from_response, + ) + + model = create_model_from_response(resp) + logging.info(f"Model Creation: Model {model_id} instantiated.") + return model else: error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" logging.error(error_message) raise Exception(error_message) - @classmethod def list( cls, @@ -200,7 +226,9 @@ def list( and ownership is None and sort_by is None ), "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" - assert len(model_ids) <= page_size, "Page size must be greater than the number of model ids" + assert ( + len(model_ids) <= page_size + ), "Page size must be greater than the number of model ids" models, total = get_model_from_ids(model_ids, api_key), len(model_ids) else: from aixplain.factories.model_factory.utils import get_assets_from_page @@ -242,7 +270,10 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", machines_url, headers=headers) response_dicts = json.loads(response.text) for dictionary in response_dicts: @@ -261,15 +292,23 @@ def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]: """ gpu_url = urljoin(config.BACKEND_URL, "sdk/model-onboarding/gpus") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", gpu_url, headers=headers) response_list = json.loads(response.text) return response_list @classmethod - def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None) -> List[Dict]: + def list_functions( + cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None + ) -> List[Dict]: """Lists supported model functions on platform. Args: @@ -286,7 +325,10 @@ def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text] if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", functions_url, headers=headers) response_dict = json.loads(response.text) if verbose: @@ -345,7 +387,10 @@ def create_asset_repo( if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } payload = { "model": { @@ -361,7 +406,9 @@ def create_asset_repo( "onboardingParams": {}, } logging.debug(f"Body: {str(payload)}") - response = _request_with_retry("post", create_url, headers=headers, json=payload) + response = _request_with_retry( + "post", create_url, headers=headers, json=payload + ) assert response.status_code == 201 @@ -381,9 +428,15 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict: login_url = urljoin(config.BACKEND_URL, "sdk/ecr/login") logging.debug(f"URL: {login_url}") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("post", login_url, headers=headers) response_dict = json.loads(response.text) return response_dict @@ -413,10 +466,15 @@ def onboard_model( if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } payload = {"image": image_tag, "sha": image_hash, "hostMachine": host_machine} logging.debug(f"Body: {str(payload)}") - response = _request_with_retry("post", onboard_url, headers=headers, json=payload) + response = _request_with_retry( + "post", onboard_url, headers=headers, json=payload + ) if response.status_code == 201: message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed." logging.info(message) @@ -446,9 +504,15 @@ def deploy_huggingface_model( supplier, model_name = hf_repo_id.split("/") deploy_url = urljoin(config.BACKEND_URL, "sdk/model-onboarding/onboard") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } body = { "model": { "name": name, @@ -472,7 +536,9 @@ def deploy_huggingface_model( return response_dicts @classmethod - def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = None): + def get_huggingface_model_status( + cls, model_id: Text, api_key: Optional[Text] = None + ): """Gets the on-boarding status of a Hugging Face model with ID MODEL_ID. Args: @@ -483,9 +549,15 @@ def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = """ status_url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", status_url, headers=headers) logging.debug(response.text) response_dicts = json.loads(response.text) diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cfbfce54..72112654 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -30,6 +30,7 @@ from aixplain.enums.supplier import Supplier from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline +from aixplain.utils.asset_cache import AssetCache from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin @@ -46,7 +47,9 @@ class PipelineFactory: backend_url = config.BACKEND_URL @classmethod - def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: + def get( + cls, pipeline_id: Text, api_key: Optional[Text] = None, use_cache: bool = True + ) -> Pipeline: """Create a 'Pipeline' object from pipeline id Args: @@ -56,48 +59,36 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: Returns: Pipeline: Created 'Pipeline' object """ - resp = None - try: - url = urljoin(cls.backend_url, f"sdk/pipelines/{pipeline_id}") - if api_key is not None: - headers = { - "Authorization": f"Token {api_key}", - "Content-Type": "application/json", - } - else: - headers = { - "Authorization": f"Token {config.TEAM_API_KEY}", - "Content-Type": "application/json", - } - logging.info(f"Start service for GET Pipeline - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - - except Exception as e: - logging.exception(e) - status_code = 400 - if resp is not None and "statusCode" in resp: - status_code = resp["statusCode"] - message = resp["message"] - message = f"Pipeline Creation: Status {status_code} - {message}" + cache = AssetCache(Pipeline) + if use_cache: + if cache.has_valid_cache(): + cached_pipeline = cache.store.data.get(pipeline_id) + if cached_pipeline: + return cached_pipeline + logging.info( + "Pipeline not found in valid cache, fetching individually..." + ) + pipeline = cls._fetch_pipeline_by_id(pipeline_id) + cache.add(pipeline) + return pipeline else: - message = f"Pipeline Creation: Unspecified Error {e}" - logging.error(message) - raise Exception(f"Status {status_code}: {message}") - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - pipeline = build_from_response(resp, load_architecture=True) - logging.info(f"Pipeline {pipeline_id} retrieved successfully.") - return pipeline - - else: - error_message = ( - f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" - ) - logging.error(error_message) - raise Exception(error_message) + try: + pipeline_list_resp = cls.list() + pipeline_dicts = pipeline_list_resp.get("results", []) + pipelines = pipeline_dicts + cache.add_list(pipelines) + + for pipeline in pipelines: + if pipeline.id == pipeline_id: + return pipeline + + except Exception as e: + logging.error(f"Error fetching pipeline list: {e}") + raise e + logging.info("Fetching pipeline directly without cache...") + pipeline = cls._fetch_pipeline_by_id(pipeline_id, config.TEAM_API_KEY) + cache.add(pipeline) + return pipeline @classmethod def create_asset_from_id(cls, pipeline_id: Text) -> Pipeline: @@ -126,9 +117,14 @@ def get_assets_from_page(cls, page_number: int) -> List[Pipeline]: } r = _request_with_retry("get", url, headers=headers) resp = r.json() - logging.info(f"Listing Pipelines: Status of getting Pipelines on Page {page_number}: {resp}") + logging.info( + f"Listing Pipelines: Status of getting Pipelines on Page {page_number}: {resp}" + ) all_pipelines = resp["items"] - pipeline_list = [build_from_response(pipeline_info_json) for pipeline_info_json in all_pipelines] + pipeline_list = [ + build_from_response(pipeline_info_json) + for pipeline_info_json in all_pipelines + ] return pipeline_list except Exception as e: error_message = f"Listing Pipelines: Error in getting Pipelines on Page {page_number}: {e}" @@ -176,7 +172,9 @@ def list( "Content-Type": "application/json", } - assert 0 < page_size <= 100, "Pipeline List Error: Page size must be greater than 0 and not exceed 100." + assert ( + 0 < page_size <= 100 + ), "Pipeline List Error: Page size must be greater than 0 and not exceed 100." payload = { "pageSize": page_size, "pageNumber": page_number, @@ -191,7 +189,6 @@ def list( if isinstance(functions, Function) is True: functions = [functions] payload["functions"] = [function.value for function in functions] - if suppliers is not None: if isinstance(suppliers, Supplier) is True: suppliers = [suppliers] @@ -205,14 +202,20 @@ def list( if input_data_types is not None: if isinstance(input_data_types, DataType) is True: input_data_types = [input_data_types] - payload["inputDataTypes"] = [data_type.value for data_type in input_data_types] + payload["inputDataTypes"] = [ + data_type.value for data_type in input_data_types + ] if output_data_types is not None: if isinstance(output_data_types, DataType) is True: output_data_types = [output_data_types] - payload["inputDataTypes"] = [data_type.value for data_type in output_data_types] + payload["inputDataTypes"] = [ + data_type.value for data_type in output_data_types + ] - logging.info(f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}") + logging.info( + f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}" + ) try: r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() @@ -227,7 +230,9 @@ def list( results = resp["items"] page_total = resp["pageTotal"] total = resp["total"] - logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") + logging.info( + f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}" + ) for pipeline in results: pipelines.append(build_from_response(pipeline)) return { @@ -294,7 +299,9 @@ def create( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ + "functionType" + ].lower() # prepare payload payload = { "name": name, @@ -307,10 +314,60 @@ def create( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info(f"Start service for POST Create Pipeline - {url} - {headers} - {json.dumps(payload)}") + logging.info( + f"Start service for POST Create Pipeline - {url} - {headers} - {json.dumps(payload)}" + ) r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() return Pipeline(response["id"], name, api_key) except Exception as e: raise Exception(e) + + @classmethod + def _fetch_pipeline_by_id( + cls, pipeline_id: Text, api_key: Optional[Text] = None + ) -> Pipeline: + """Fetch a Pipeline by ID from the backend (no cache)""" + + resp = None + try: + url = urljoin(cls.backend_url, f"sdk/pipelines/{pipeline_id}") + if api_key is not None: + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } + else: + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } + + logging.info(f"Start service for GET Pipeline - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + + except Exception as e: + logging.exception(e) + status_code = 400 + if resp is not None and "statusCode" in resp: + status_code = resp["statusCode"] + message = resp["message"] + message = f"Pipeline Creation: Status {status_code} - {message}" + else: + message = f"Pipeline Creation: Unspecified Error {e}" + logging.error(message) + raise Exception(f"Status {status_code}: {message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + pipeline = build_from_response(resp, load_architecture=True) + logging.info(f"Pipeline {pipeline_id} retrieved successfully.") + return pipeline + + else: + error_message = f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 5c3fca5b..8a68a3c3 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -37,7 +37,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus, AssetCache +from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -61,8 +61,6 @@ class Agent(Model): api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. """ - # AgentCache = AssetCache("agents", "agents") - # AgentEnum, AgentDetails = AgentCache.load_assets() is_valid: bool @@ -82,21 +80,6 @@ def __init__( tasks: List[AgentTask] = [], **additional_info, ) -> None: - - if id in self.__class__.AgentDetails: - cached_agent = self.__class__.AgentDetails[id] - name = cached_agent.get("name", name) - description = cached_agent.get("description", description) - tools = cached_agent.get("tools", tools) - llm_id = cached_agent.get("llm_id", llm_id) - api_key = cached_agent.get("api_key", api_key) - supplier = cached_agent.get("supplier", supplier) - version = cached_agent.get("version", version) - cost = cached_agent.get("cost", cost) - status = cached_agent.get("status", status) - tasks = cached_agent.get("tasks", tasks) - additional_info = cached_agent.get("additional_info", additional_info) - """Create an Agent with the necessary information. Args: @@ -142,13 +125,17 @@ def _validate(self) -> None: except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") - assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." + assert ( + llm.function == Function.TEXT_GENERATION + ), "Large Language Model must be a text generation model." for tool in self.tools: if isinstance(tool, Tool): tool.validate() elif isinstance(tool, Model): - assert not isinstance(tool, Agent), "Agent cannot contain another Agent." + assert not isinstance( + tool, Agent + ), "Agent cannot contain another Agent." def validate(self, raise_exception: bool = False) -> bool: """Validate the Agent.""" @@ -161,7 +148,9 @@ def validate(self, raise_exception: bool = False) -> bool: raise e else: logging.warning(f"Agent Validation Error: {e}") - logging.warning("You won't be able to run the Agent until the issues are handled manually.") + logging.warning( + "You won't be able to run the Agent until the issues are handled manually." + ) return self.is_valid def run( @@ -218,7 +207,9 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result = self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) result_data = result.data return AgentResponse( status=ResponseStatus.SUCCESS, @@ -281,12 +272,18 @@ def run_async( from aixplain.factories.file_factory import FileFactory if not self.is_valid: - raise Exception("Agent is not valid. Please validate the agent before running.") + raise Exception( + "Agent is not valid. Please validate the agent before running." + ) - assert data is not None or query is not None, "Either 'data' or 'query' must be provided." + assert ( + data is not None or query is not None + ), "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + assert ( + "query" in data and data["query"] is not None + ), "When providing a dictionary, 'query' must be provided." query = data.get("query") if session_id is None: session_id = data.get("session_id") @@ -299,7 +296,9 @@ def run_async( # process content inputs if content is not None: - assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." + assert ( + FileFactory.check_storage_type(query) == StorageType.TEXT + ), "When providing 'content', query must be text." if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -308,7 +307,9 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." + assert ( + "{{" + key + "}}" in query + ), f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") @@ -323,8 +324,16 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), - "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), + "maxTokens": ( + parameters["max_tokens"] + if "max_tokens" in parameters + else max_tokens + ), + "maxIterations": ( + parameters["max_iterations"] + if "max_iterations" in parameters + else max_iterations + ), "outputFormat": output_format.value, }, } @@ -358,7 +367,11 @@ def to_dict(self) -> Dict: "assets": [tool.to_dict() for tool in self.tools], "description": self.description, "role": self.instructions, - "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), + "supplier": ( + self.supplier.value["code"] + if isinstance(self.supplier, Supplier) + else self.supplier + ), "version": self.version, "llmId": self.llm_id, "status": self.status.value, @@ -396,7 +409,8 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " + "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -408,7 +422,9 @@ def update(self) -> None: payload = self.to_dict() - logging.debug(f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}") + logging.debug( + f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}" + ) resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) @@ -427,10 +443,38 @@ def save(self) -> None: self.update() def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." + assert ( + self.status == AssetStatus.DRAFT + ), "Agent must be in draft status to be deployed." assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." self.status = AssetStatus.ONBOARDED self.update() def __repr__(self): return f"Agent(id={self.id}, name={self.name}, function={self.function})" + + @classmethod + def from_dict(cls, data: dict) -> "Agent": + return cls( + id=data.get("id"), + name=data.get("name"), + description=data.get("description", ""), + instructions=data.get("role", ""), + tools=[], + llm_id=data.get("llmId"), + api_key=data.get("api_key", config.TEAM_API_KEY), + supplier=data.get("supplier", "aiXplain"), + version=data.get("version"), + cost=data.get("cost"), + status=( + AssetStatus(data["status"]) if data.get("status") else AssetStatus.DRAFT + ), + tasks=( + [AgentTask.from_dict(t) for t in data.get("tasks", [])] + if "tasks" in data + else [] + ), + function=( + Function(data["function"]) if data.get("function") is not None else None + ), + ) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 3bac37af..25aaa414 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -109,7 +109,9 @@ def to_dict(self) -> Dict: Returns: Dict: Model Information """ - clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None} + clean_additional_info = { + k: v for k, v in self.additional_info.items() if v is not None + } return { "id": self.id, "name": self.name, @@ -119,6 +121,7 @@ def to_dict(self) -> Dict: "input_params": self.input_params, "output_params": self.output_params, "model_params": self.model_params.to_dict(), + "function": self.function, } def get_parameters(self) -> ModelParameters: @@ -133,7 +136,11 @@ def __repr__(self): return f"" def sync_poll( - self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300 + self, + poll_url: Text, + name: Text = "model_process", + wait_time: float = 0.5, + timeout: float = 300, ) -> ModelResponse: """Keeps polling the platform to check whether an asynchronous call is done. @@ -164,15 +171,21 @@ def sync_poll( wait_time *= 1.1 except Exception as e: response_body = ModelResponse( - status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", ) logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: - logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") + logging.debug( + f"Polling for Model: Final status of polling for {name}: {response_body}" + ) else: response_body = ModelResponse( - status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", ) logging.error( f"Polling for Model: Final status of polling for {name}: No response in {timeout} seconds - {response_body}" @@ -199,7 +212,9 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: status = ResponseStatus.FAILED else: status = ResponseStatus.IN_PROGRESS - logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") + logging.debug( + f"Single Poll for Model: Status of polling for {name}: {resp}" + ) return ModelResponse( status=resp.pop("status", status), data=resp.pop("data", ""), @@ -249,12 +264,18 @@ def run( try: poll_url = response["url"] end = time.time() - return self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + return self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Model Run: Error in running for {name}: {e}") end = time.time() - response = {"status": "FAILED", "error_message": msg, "runTime": end - start} + response = { + "status": "FAILED", + "error_message": msg, + "runTime": end - start, + } return ModelResponse( status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), @@ -268,7 +289,10 @@ def run( ) def run_async( - self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None + self, + data: Union[Text, Dict], + name: Text = "model_process", + parameters: Optional[Dict] = None, ) -> ModelResponse: """Runs asynchronously a model call. @@ -313,7 +337,9 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): resp = None try: url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs") - logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}") + logging.info( + f"Start service for GET Check FineTune status Model - {url} - {headers}" + ) r = _request_with_retry("get", url, headers=headers) resp = r.json() finetune_status = AssetStatus(resp["finetuneStatus"]) @@ -340,9 +366,21 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): status = FinetuneStatus( status=finetune_status, model_status=model_status, - epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None, - training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None, - validation_loss=float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None, + epoch=( + float(log["epoch"]) + if "epoch" in log and log["epoch"] is not None + else None + ), + training_loss=( + float(log["trainLoss"]) + if "trainLoss" in log and log["trainLoss"] is not None + else None + ), + validation_loss=( + float(log["evalLoss"]) + if "evalLoss" in log and log["evalLoss"] is not None + else None + ), ) else: status = FinetuneStatus( @@ -350,7 +388,9 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): model_status=model_status, ) - logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.") + logging.info( + f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}." + ) return status except Exception: message = "" @@ -365,7 +405,10 @@ def delete(self) -> None: """Delete Model service""" try: url = urljoin(self.backend_url, f"sdk/models/{self.id}") - headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {self.api_key}", + "Content-Type": "application/json", + } logging.info(f"Start service for DELETE Model - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) if r.status_code != 200: @@ -373,4 +416,27 @@ def delete(self) -> None: except Exception: message = "Model Deletion Error: Make sure the model exists and you are the owner." logging.error(message) - raise Exception(f"{message}") \ No newline at end of file + raise Exception(f"{message}") + + @classmethod + def from_dict(cls, data: Dict) -> "Model": + return cls( + id=data.get("id", ""), + name=data.get("name", ""), + description=data.get("description", ""), + api_key=data.get("api_key", config.TEAM_API_KEY), + supplier=data.get("supplier", "aiXplain"), + version=data.get("version", "1.0"), + function=Function(data.get("function")), + is_subscribed=data.get("is_subscribed", False), + cost=data.get("cost"), + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None + ), + input_params=data.get("input_params"), + output_params=data.get("output_params"), + model_params=data.get("model_params"), + **data.get("additional_info", {}), + ) diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 206ae3b3..18ba40ca 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -32,7 +32,6 @@ from aixplain.utils.file_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin -from aixplain.enums import AssetCache from aixplain.modules.pipeline.response import PipelineResponse @@ -75,17 +74,6 @@ def __init__( status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ - PipelineCache = AssetCache("pipelines", "pipelines") - _, PipelineDetails = PipelineCache.load_assets() - if id in PipelineDetails: - cached_pipeline = PipelineDetails[id] - name = cached_pipeline["name"] - api_key = cached_pipeline["api_key"] - supplier = cached_pipeline["supplier"] - version = cached_pipeline["version"] - status = cached_pipeline["status"] - additional_info = cached_pipeline["architecture"] - if not name: raise ValueError("Pipeline name is required") @@ -126,28 +114,39 @@ def __polling( while not response_body["completed"] and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) - logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") + logging.debug( + f"Polling for Pipeline: Status of polling for {name} : {response_body}" + ) end = time.time() if not response_body["completed"]: time.sleep(wait_time) if wait_time < 60: wait_time *= 1.1 except Exception: - logging.error(f"Polling for Pipeline: polling for {name} : Continue") + logging.error( + f"Polling for Pipeline '{self.id}': polling for {name} ({poll_url}): Continue" + ) break if response_body["status"] == ResponseStatus.SUCCESS: try: - logging.debug(f"Polling for Pipeline: Final status of polling for {name} : SUCCESS - {response_body}") + logging.debug( + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): SUCCESS - {response_body}" + ) except Exception: - logging.error(f"Polling for Pipeline: Final status of polling for {name} : ERROR - {response_body}") + logging.error( + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): ERROR - {response_body}" + ) else: logging.error( - f"Polling for Pipeline: Final status of polling for {name} : No response in {timeout} seconds - {response_body}" + f"Polling for Pipeline '{self.id}' - Final status of polling for {name} ({poll_url}): No response in {timeout} seconds - {response_body}" ) return response_body def poll( - self, poll_url: Text, name: Text = "pipeline_process", response_version: Text = "v2" + self, + poll_url: Text, + name: Text = "pipeline_process", + response_version: Text = "v2", ) -> Union[Dict, PipelineResponse]: """Poll the platform to check whether an asynchronous call is done. @@ -171,7 +170,9 @@ def poll( resp["data"] = json.loads(resp["data"])["response"] except Exception: resp = r.json() - logging.info(f"Single Poll for Pipeline: Status of polling for {name} : {resp}") + logging.info( + f"Single Poll for Pipeline '{self.id}' - Status of polling for {name} ({poll_url}): {resp}" + ) if response_version == "v1": return resp status = ResponseStatus(resp.pop("status", "failed")) @@ -204,7 +205,9 @@ def run( ) -> Union[Dict, PipelineResponse]: start = time.time() try: - response = self.run_async(data, data_asset=data_asset, name=name, version=version, **kwargs) + response = self.run_async( + data, data_asset=data_asset, name=name, version=version, **kwargs + ) if response["status"] == ResponseStatus.FAILED: end = time.time() if response_version == "v1": @@ -221,15 +224,23 @@ def run( **kwargs, ) poll_url = response["url"] - polling_response = self.__polling(poll_url, name=name, timeout=timeout, wait_time=wait_time) + polling_response = self.__polling( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) end = time.time() - status = ResponseStatus(polling_response["status"]) + if isinstance(polling_response, dict): + status = ResponseStatus(polling_response.get("status", "failed")) + completed = polling_response.get("completed", False) + else: + status = polling_response.status + completed = getattr(polling_response, "completed", False) if response_version == "v1": polling_response["elapsed_time"] = end - start return polling_response status = ResponseStatus(polling_response.status) return PipelineResponse( status=status, + completed=completed, error=polling_response.error, elapsed_time=end - start, data=getattr(polling_response, "data", {}), @@ -292,7 +303,10 @@ def __prepare_payload( try: payload = json.loads(data) if isinstance(payload, dict) is False: - if isinstance(payload, int) is True or isinstance(payload, float) is True: + if ( + isinstance(payload, int) is True + or isinstance(payload, float) is True + ): payload = str(payload) payload = {"data": payload} except Exception: @@ -330,7 +344,9 @@ def __prepare_payload( asset_payload["dataAsset"]["dataset_id"] = dasset.id source_data_list = [ - dfield for dfield in dasset.source_data if dasset.source_data[dfield].id == data[node_label] + dfield + for dfield in dasset.source_data + if dasset.source_data[dfield].id == data[node_label] ] if len(source_data_list) > 0: @@ -404,7 +420,9 @@ def run_async( try: if 200 <= r.status_code < 300: resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") + logging.info( + f"Result of request for {name} - {r.status_code} - {resp}" + ) if response_version == "v1": return resp res = PipelineResponse( @@ -428,11 +446,11 @@ def run_async( error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." else: status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) + error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - logging.error(f"Error in request for {name} - {r.status_code}: {error}") + logging.error( + f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}" + ) if response_version == "v1": return { "status": "failed", @@ -485,7 +503,8 @@ def update( stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " + "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -500,7 +519,9 @@ def update( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ + "functionType" + ].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -518,7 +539,9 @@ def update( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info(f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}") + logging.info( + f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}" + ) r = _request_with_retry("put", url, headers=headers, json=payload) response = r.json() logging.info(f"Pipeline {response['id']} Updated.") @@ -569,11 +592,15 @@ def save( ), "Pipeline Update Error: Make sure the pipeline to be saved is in a JSON file." with open(pipeline) as f: pipeline = json.load(f) - self.update(pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key) + self.update( + pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key + ) for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ + "functionType" + ].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -590,7 +617,9 @@ def save( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info(f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}") + logging.info( + f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}" + ) r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() self.id = response["id"] @@ -600,12 +629,59 @@ def save( def deploy(self, api_key: Optional[Text] = None) -> None: """Deploy the Pipeline.""" - assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." - assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." + assert ( + self.status == "draft" + ), "Pipeline Deployment Error: Pipeline must be in draft status." + assert ( + self.status != "onboarded" + ), "Pipeline Deployment Error: Pipeline must be onboarded." pipeline = self.to_dict() - self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) + self.update( + pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name + ) self.status = AssetStatus.ONBOARDED def __repr__(self): return f"Pipeline(id={self.id}, name={self.name})" + + @classmethod + def from_dict(cls, data: Dict) -> "Pipeline": + """ + Create a Pipeline instance from a dictionary. + + Args: + data (Dict): A dictionary containing pipeline attributes. + + Returns: + Pipeline: An instance of the Pipeline class. + """ + return cls( + id=data.get("id"), + name=data.get("name"), + api_key=data.get("api_key", config.TEAM_API_KEY), + url=data.get("url", config.BACKEND_URL), + supplier=data.get("supplier", "aiXplain"), + version=data.get("version", "1.0"), + status=data.get("status", AssetStatus.DRAFT), + **data.get("additional_info", {}), + ) + + def to_dict(self) -> Dict: + """ + Serialize the Pipeline object to a dictionary. + + Returns: + Dict: A dictionary representation of the Pipeline instance. + """ + logging.info("corect to dict") + return { + "id": self.id, + "name": self.name, + "api_key": self.api_key, + "url": self.url, + "supplier": self.supplier, + "version": self.version, + "status": self.status.name if hasattr(self.status, "name") else self.status, + "additional_info": self.additional_info, + } diff --git a/aixplain/modules/pipeline/default.py b/aixplain/modules/pipeline/default.py index 41ae3c71..226deefd 100644 --- a/aixplain/modules/pipeline/default.py +++ b/aixplain/modules/pipeline/default.py @@ -1,5 +1,6 @@ from .asset import Pipeline as PipelineAsset from .designer import DesignerPipeline +from enum import Enum class DefaultPipeline(PipelineAsset, DesignerPipeline): @@ -13,4 +14,18 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) def to_dict(self) -> dict: - return self.serialize() + data = self.__dict__.copy() + + for key, value in data.items(): + if isinstance(value, Enum): + data[key] = value.value + + elif isinstance(value, list): + data[key] = [ + v.to_dict() if hasattr(v, "to_dict") else str(v) for v in value + ] + + elif hasattr(value, "to_dict"): + data[key] = value.to_dict() + + return data diff --git a/aixplain/enums/asset_cache.py b/aixplain/utils/asset_cache.py similarity index 57% rename from aixplain/enums/asset_cache.py rename to aixplain/utils/asset_cache.py index 836ede1a..982b8d95 100644 --- a/aixplain/enums/asset_cache.py +++ b/aixplain/utils/asset_cache.py @@ -3,12 +3,10 @@ import json import time from typing import Dict, Optional -from dataclasses import dataclass, asdict +from dataclasses import dataclass from filelock import FileLock from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry -from urllib.parse import urljoin from typing import TypeVar, Generic, Type from typing import List @@ -21,26 +19,13 @@ CACHE_FOLDER = ".cache" DEFAULT_CACHE_EXPIRY = 86400 -@dataclass -class Model: - id: str - name: str = "" - description: str = "" - api_key: str = config.TEAM_API_KEY - supplier: str = "aiXplain" - version: str = "1.0" - status: str = "onboarded" - created_at: str = "" - - @classmethod - def from_dict(cls, data: Dict) -> "Model": - return cls(**data) @dataclass class Store(Generic[T]): data: Dict[str, T] expiry: int + class AssetCache(Generic[T]): """ A modular caching system to handle different asset types (Models, Pipelines, Agents). @@ -95,66 +80,47 @@ def load(self): with open(self.cache_file, "r") as f: try: cache_data = json.load(f) - except Exception as e: - # data is corrupted, invalidate the cache - self.invalidate() - logging.warning(f"Failed to parse cache file: {e}") - return - - try: expiry = cache_data["expiry"] raw_data = cache_data["data"] parsed_data = { - k: self.cls( - id=v.get("id", ""), - name=v.get("name", ""), - description=v.get("description", ""), - api_key=v.get("api_key", config.TEAM_API_KEY), - supplier=v.get("supplier", "aiXplain"), - version=v.get("version", "1.0"), - status=v.get("status", "onboarded"), - created_at=v.get("created_at", ""), - ) for k, v in raw_data.items() + k: self.cls.from_dict(v) for k, v in raw_data.items() } - self.store = Store(data=parsed_data, expiry=expiry) + + if self.store.expiry < time.time(): + logger.warning(f"Cache expired for {self.cls.__name__}") + self.invalidate() + except Exception as e: self.invalidate() - logging.warning(f"Failed to load cache data: {e}") - + logger.warning(f"Failed to load cache data: {e}") if self.store.expiry < time.time(): logger.warning( f"Cache expired, invalidating cache for {self.cls.__name__}" ) - # cache expired, invalidate the cache self.invalidate() return def save(self): + os.makedirs(CACHE_FOLDER, exist_ok=True) with FileLock(self.lock_file): with open(self.cache_file, "w") as f: - # serialize the data manually + data_dict = {} + for asset_id, asset in self.store.data.items(): + try: + data_dict[asset_id] = asset.to_dict() + except Exception as e: + logger.error(f"Error serializing {asset_id}: {e}") serializable_store = { - "expiry": self.compute_expiry(), - "data": { - asset_id: { - "id": model.id, - "name": model.name, - "description": model.description, - "api_key": model.api_key, - "supplier": model.supplier, - "version": model.version, - "created_at": model.created_at.isoformat() if hasattr(model.created_at, "isoformat") else model.created_at, - } - for asset_id, model in self.store.data.items() - }, + "expiry": self.store.expiry, + "data": data_dict, } - json.dump(serializable_store, f) + json.dump(serializable_store, f, indent=4) def get(self, asset_id: str) -> Optional[T]: return self.store.data.get(asset_id) @@ -163,12 +129,12 @@ def add(self, asset: T): self.store.data[asset.id] = asset self.save() - def add_model_list(self, models: List[T]): - self.store.data = {model.id: model for model in models} + def add_list(self, assets: List[T]): + self.store.data = {asset.id: asset for asset in assets} self.save() - def get_all_models(self) -> List[T]: + def get_all(self) -> List[T]: return list(self.store.data.values()) def has_valid_cache(self) -> bool: - return self.store.expiry >= time.time() \ No newline at end of file + return self.store.expiry >= time.time() and bool(self.store.data) From c273cb3f0f723eb2ad9c159ed1d574be5e36831d Mon Sep 17 00:00:00 2001 From: xainaz Date: Sun, 27 Apr 2025 14:39:20 +0300 Subject: [PATCH 11/13] removed agent and pipeline changes, made metadata changes --- aixplain/enums/function.py | 36 +++- aixplain/enums/language.py | 34 ++- aixplain/enums/license.py | 32 ++- aixplain/factories/agent_factory/__init__.py | 196 +++++------------- .../factories/pipeline_factory/__init__.py | 163 +++++---------- aixplain/modules/agent/__init__.py | 97 ++------- aixplain/modules/pipeline/asset.py | 132 +++--------- 7 files changed, 225 insertions(+), 465 deletions(-) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index d606b98f..8c25111b 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -33,20 +33,38 @@ CACHE_FILE = f"{CACHE_FOLDER}/functions.json" LOCK_FILE = f"{CACHE_FILE}.lock" +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any +@dataclass class FunctionMetadata: - def __init__(self, data: dict): - self.__dict__.update(data) - - def __repr__(self): - return f"" + id: str + name: str + description: Optional[str] = None + params: List[Dict[str, Any]] = field(default_factory=list) + output: List[Dict[str, Any]] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict: - return self.__dict__ + return { + "id": self.id, + "name": self.name, + "description": self.description, + "params": self.params, + "output": self.output, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict): - return cls(data) + return cls( + id=data.get("id"), + name=data.get("name"), + description=data.get("description"), + params=data.get("params", []), + output=data.get("output", []), + metadata={k: v for k, v in data.items() if k not in {"id", "name", "description", "params", "output"}}, + ) def load_functions(): @@ -71,7 +89,7 @@ def load_functions(): ) resp = r.json() results = resp.get("results") - function_objects = [FunctionMetadata(f) for f in results] + function_objects = [FunctionMetadata.from_dict(f) for f in results] cache.add_list(function_objects) class Function(str, Enum): @@ -142,4 +160,4 @@ def __init__(self, input_params: Dict): ) -Function, FunctionInputOutput = load_functions() +Function, FunctionInputOutput = load_functions() \ No newline at end of file diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index f8a61381..a660024a 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -27,19 +27,35 @@ from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.asset_cache import AssetCache import logging +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional - +@dataclass class LanguageMetadata: - def __init__(self, data: dict): - self.__dict__.update(data) - - def to_dict(self): - return self.__dict__ + id: str + value: str + label: str + dialects: List[Dict[str, str]] = field(default_factory=list) + scripts: List[Any] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "id": self.id, + "value": self.value, + "label": self.label, + "dialects": self.dialects, + "scripts": self.scripts, + } @classmethod def from_dict(cls, data: dict): - return cls(data) - + return cls( + id=data.get("id"), + value=data.get("value"), + label=data.get("label"), + dialects=data.get("dialects", []), + scripts=data.get("scripts", []), + ) def load_languages(): api_key = config.TEAM_API_KEY @@ -60,7 +76,7 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - lang_entries = [LanguageMetadata(item) for item in resp] + lang_entries = [LanguageMetadata.from_dict(item) for item in resp] cache.add_list(lang_entries) languages = {} diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 5489173a..6914f647 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -29,16 +29,34 @@ from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER -class LicenseMetadata: - def __init__(self, data: dict): - self.__dict__.update(data) +from dataclasses import dataclass - def to_dict(self): - return self.__dict__ +@dataclass +class LicenseMetadata: + id: str + name: str + description: str + url: str + allowCustomUrl: bool + + def to_dict(self) -> dict: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "url": self.url, + "allowCustomUrl": self.allowCustomUrl, + } @classmethod def from_dict(cls, data: dict): - return cls(data) + return cls( + id=data.get("id"), + name=data.get("name"), + description=data.get("description"), + url=data.get("url"), + allowCustomUrl=data.get("allowCustomUrl", False), + ) def load_licenses(): @@ -62,7 +80,7 @@ def load_licenses(): f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation.' ) resp = r.json() - license_objects = [LicenseMetadata(item) for item in resp] + license_objects = [LicenseMetadata.from_dict(item) for item in resp] cache.add_list(license_objects) licenses = {"_".join(lic.name.split()): lic.id for lic in license_objects} diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index ceb3c3d5..10fe36e9 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -39,7 +39,6 @@ from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config -from aixplain.utils.asset_cache import AssetCache from typing import Callable, Dict, List, Optional, Text, Union from aixplain.utils.file_utils import _request_with_retry @@ -82,9 +81,9 @@ def create( Agent: created Agent """ warnings.warn( - "The 'instructions' parameter was recently added and serves the same purpose as 'description' did previously: set the role of the agent as a system prompt. " - "The 'description' parameter is still required and should be used to set a short summary of the agent's purpose. " - "For the next releases, the 'instructions' parameter will be required.", + "Use `instructions` to define the **system prompt**. " + "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " + "Note: In upcoming releases, `instructions` will become a required parameter.", UserWarning, ) from aixplain.factories.agent_factory.utils import build_agent @@ -101,34 +100,21 @@ def create( payload = { "name": name, "assets": [ - ( - tool.to_dict() - if isinstance(tool, Tool) - else { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "supplier": ( - tool.supplier.value["code"] - if isinstance(tool.supplier, Supplier) - else tool.supplier - ), - "parameters": ( - tool.get_parameters().to_list() - if hasattr(tool, "get_parameters") - and tool.get_parameters() is not None - else None - ), - "function": ( - tool.function - if hasattr(tool, "function") and tool.function is not None - else None - ), - "type": "model", - "version": tool.version if hasattr(tool, "version") else None, - "assetId": tool.id, - } - ) + tool.to_dict() + if isinstance(tool, Tool) + else { + "id": tool.id, + "name": tool.name, + "description": tool.description, + "supplier": tool.supplier.value["code"] if isinstance(tool.supplier, Supplier) else tool.supplier, + "parameters": tool.get_parameters().to_list() + if hasattr(tool, "get_parameters") and tool.get_parameters() is not None + else None, + "function": tool.function if hasattr(tool, "function") and tool.function is not None else None, + "type": "model", + "version": tool.version if hasattr(tool, "version") else None, + "assetId": tool.id, + } for tool in tools ], "description": description, @@ -143,15 +129,11 @@ def create( agent.validate(raise_exception=True) response = "Unspecified error" try: - logging.debug( - f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}" - ) + logging.debug(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(agent.to_dict())}") r = _request_with_retry("post", url, headers=headers, json=agent.to_dict()) response = r.json() except Exception: - raise Exception( - "Agent Onboarding Error: Please contact the administrators." - ) + raise Exception("Agent Onboarding Error: Please contact the administrators.") if 200 <= r.status_code < 300: agent = build_agent(payload=response, tools=tools, api_key=api_key) @@ -170,18 +152,9 @@ def create( @classmethod def create_task( - cls, - name: Text, - description: Text, - expected_output: Text, - dependencies: Optional[List[Text]] = None, + cls, name: Text, description: Text, expected_output: Text, dependencies: Optional[List[Text]] = None ) -> AgentTask: - return AgentTask( - name=name, - description=description, - expected_output=expected_output, - dependencies=dependencies, - ) + return AgentTask(name=name, description=description, expected_output=expected_output, dependencies=dependencies) @classmethod def create_model_tool( @@ -199,27 +172,14 @@ def create_model_tool( if supplier is not None: if isinstance(supplier, str): for supplier_ in Supplier: - if supplier.lower() in [ - supplier_.value["code"].lower(), - supplier_.value["name"].lower(), - ]: + if supplier.lower() in [supplier_.value["code"].lower(), supplier_.value["name"].lower()]: supplier = supplier_ break - assert isinstance( - supplier, Supplier - ), f"Supplier {supplier} is not a valid supplier" - return ModelTool( - function=function, - supplier=supplier, - model=model, - description=description, - parameters=parameters, - ) + assert isinstance(supplier, Supplier), f"Supplier {supplier} is not a valid supplier" + return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) @classmethod - def create_pipeline_tool( - cls, description: Text, pipeline: Union[Pipeline, Text] - ) -> PipelineTool: + def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: """Create a new pipeline tool.""" return PipelineTool(description=description, pipeline=pipeline) @@ -229,9 +189,7 @@ def create_python_interpreter_tool(cls) -> PythonInterpreterTool: return PythonInterpreterTool() @classmethod - def create_custom_python_code_tool( - cls, code: Union[Text, Callable], description: Text = "" - ) -> CustomPythonCodeTool: + def create_custom_python_code_tool(cls, code: Union[Text, Callable], description: Text = "") -> CustomPythonCodeTool: """Create a new custom python code tool.""" return CustomPythonCodeTool(description=description, code=code) @@ -296,9 +254,7 @@ def create_sql_tool( # Already the correct type, no conversion needed pass else: - raise SQLToolError( - f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}" - ) + raise SQLToolError(f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}") database_path = None # Final database path to pass to SQLTool @@ -308,14 +264,17 @@ def create_sql_tool( raise SQLToolError(f"CSV file '{source}' does not exist") if not source.endswith(".csv"): raise SQLToolError(f"File '{source}' is not a CSV file") + if tables and len(tables) > 1: + raise SQLToolError("CSV source type only supports one table") # Create database name from CSV filename or use custom table name base_name = os.path.splitext(os.path.basename(source))[0] db_path = os.path.join(os.path.dirname(source), f"{base_name}.db") + table_name = tables[0] if tables else None try: # Create database from CSV - schema = create_database_from_csv(source, db_path) + schema = create_database_from_csv(source, db_path, table_name) database_path = db_path # Get table names if not provided @@ -327,9 +286,7 @@ def create_sql_tool( try: os.remove(db_path) except Exception as cleanup_error: - warnings.warn( - f"Failed to remove temporary database file '{db_path}': {str(cleanup_error)}" - ) + warnings.warn(f"Failed to remove temporary database file '{db_path}': {str(cleanup_error)}") raise SQLToolError(f"Failed to create database from CSV: {str(e)}") # Handle SQLite source type @@ -337,9 +294,7 @@ def create_sql_tool( if not os.path.exists(source): raise SQLToolError(f"Database '{source}' does not exist") if not source.endswith(".db") and not source.endswith(".sqlite"): - raise SQLToolError( - f"Database '{source}' must have .db or .sqlite extension" - ) + raise SQLToolError(f"Database '{source}' must have .db or .sqlite extension") database_path = source @@ -379,9 +334,7 @@ def list(cls) -> Dict: resp = {} payload = {} - logging.info( - f"Start service for GET List Agents - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for GET List Agents - {url} - {headers} - {json.dumps(payload)}") try: r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -393,17 +346,10 @@ def list(cls) -> Dict: results = resp page_total = len(results) total = len(results) - logging.info( - f"Response for GET List Agents - Page Total: {page_total} / Total: {total}" - ) + logging.info(f"Response for GET List Agents - Page Total: {page_total} / Total: {total}") for agent in results: agents.append(build_agent(agent)) - return { - "results": agents, - "page_total": page_total, - "page_number": 0, - "total": total, - } + return {"results": agents, "page_total": page_total, "page_number": 0, "total": total} else: error_msg = "Agent Listing Error: Please contact the administrators." if isinstance(resp, dict) and "message" in resp: @@ -413,60 +359,22 @@ def list(cls) -> Dict: raise Exception(error_msg) @classmethod - def get( - cls, agent_id: Text, api_key: Optional[Text] = None, use_cache: bool = True - ) -> Agent: + def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: + """Get agent by id.""" from aixplain.factories.agent_factory.utils import build_agent - from aixplain.utils.asset_cache import AssetCache - - cache = AssetCache(Agent) - api_key = api_key or config.TEAM_API_KEY - - if use_cache: - if cache.has_valid_cache(): - cached_agent = cache.store.data.get(agent_id) - if cached_agent: - logging.info(f"Agent {agent_id} retrieved from valid cache.") - return cached_agent - else: - logging.info( - "No valid cache found — fetching full agent list to build cache." - ) - try: - agent_list_resp = cls.list() - agents = agent_list_resp.get("results", []) - cache.add_list(agents) - logging.info(f"Cache rebuilt with {len(agents)} agents.") - - for agent in agents: - if agent.id == agent_id: - logging.info( - f"Agent {agent_id} retrieved from newly built cache." - ) - return agent - except Exception as e: - logging.error(f"Error rebuilding agent cache: {e}") - raise e - # Fallback: direct fetch if cache not used or agent not found - logging.info(f"Fetching agent {agent_id} directly from backend.") url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent_id}") - headers = {"x-api-key": api_key, "Content-Type": "application/json"} - - try: - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if 200 <= r.status_code < 300: - agent = build_agent(resp) - cache.add(agent) # still helpful for future use - logging.info( - f"Agent {agent_id} fetched from backend and added to cache." - ) - return agent - else: - msg = resp.get("message", "Please contact the administrators.") - raise Exception(f"Agent Get Error (HTTP {r.status_code}): {msg}") - except Exception as e: - logging.exception(f"Agent Get Error: {e}") - raise + api_key = api_key if api_key is not None else config.TEAM_API_KEY + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + logging.info(f"Start service for GET Agent - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + if 200 <= r.status_code < 300: + return build_agent(resp) + else: + msg = "Please contact the administrators." + if "message" in resp: + msg = resp["message"] + error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" + raise Exception(error_msg) \ No newline at end of file diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index 72112654..ce538d58 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -30,7 +30,6 @@ from aixplain.enums.supplier import Supplier from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline -from aixplain.utils.asset_cache import AssetCache from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin @@ -47,9 +46,7 @@ class PipelineFactory: backend_url = config.BACKEND_URL @classmethod - def get( - cls, pipeline_id: Text, api_key: Optional[Text] = None, use_cache: bool = True - ) -> Pipeline: + def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: """Create a 'Pipeline' object from pipeline id Args: @@ -59,36 +56,48 @@ def get( Returns: Pipeline: Created 'Pipeline' object """ - cache = AssetCache(Pipeline) - if use_cache: - if cache.has_valid_cache(): - cached_pipeline = cache.store.data.get(pipeline_id) - if cached_pipeline: - return cached_pipeline - logging.info( - "Pipeline not found in valid cache, fetching individually..." - ) - pipeline = cls._fetch_pipeline_by_id(pipeline_id) - cache.add(pipeline) - return pipeline + resp = None + try: + url = urljoin(cls.backend_url, f"sdk/pipelines/{pipeline_id}") + if api_key is not None: + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } + else: + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } + logging.info(f"Start service for GET Pipeline - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + + except Exception as e: + logging.exception(e) + status_code = 400 + if resp is not None and "statusCode" in resp: + status_code = resp["statusCode"] + message = resp["message"] + message = f"Pipeline Creation: Status {status_code} - {message}" else: - try: - pipeline_list_resp = cls.list() - pipeline_dicts = pipeline_list_resp.get("results", []) - pipelines = pipeline_dicts - cache.add_list(pipelines) - - for pipeline in pipelines: - if pipeline.id == pipeline_id: - return pipeline - - except Exception as e: - logging.error(f"Error fetching pipeline list: {e}") - raise e - logging.info("Fetching pipeline directly without cache...") - pipeline = cls._fetch_pipeline_by_id(pipeline_id, config.TEAM_API_KEY) - cache.add(pipeline) - return pipeline + message = f"Pipeline Creation: Unspecified Error {e}" + logging.error(message) + raise Exception(f"Status {status_code}: {message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + pipeline = build_from_response(resp, load_architecture=True) + logging.info(f"Pipeline {pipeline_id} retrieved successfully.") + return pipeline + + else: + error_message = ( + f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, pipeline_id: Text) -> Pipeline: @@ -117,14 +126,9 @@ def get_assets_from_page(cls, page_number: int) -> List[Pipeline]: } r = _request_with_retry("get", url, headers=headers) resp = r.json() - logging.info( - f"Listing Pipelines: Status of getting Pipelines on Page {page_number}: {resp}" - ) + logging.info(f"Listing Pipelines: Status of getting Pipelines on Page {page_number}: {resp}") all_pipelines = resp["items"] - pipeline_list = [ - build_from_response(pipeline_info_json) - for pipeline_info_json in all_pipelines - ] + pipeline_list = [build_from_response(pipeline_info_json) for pipeline_info_json in all_pipelines] return pipeline_list except Exception as e: error_message = f"Listing Pipelines: Error in getting Pipelines on Page {page_number}: {e}" @@ -172,9 +176,7 @@ def list( "Content-Type": "application/json", } - assert ( - 0 < page_size <= 100 - ), "Pipeline List Error: Page size must be greater than 0 and not exceed 100." + assert 0 < page_size <= 100, "Pipeline List Error: Page size must be greater than 0 and not exceed 100." payload = { "pageSize": page_size, "pageNumber": page_number, @@ -189,6 +191,7 @@ def list( if isinstance(functions, Function) is True: functions = [functions] payload["functions"] = [function.value for function in functions] + if suppliers is not None: if isinstance(suppliers, Supplier) is True: suppliers = [suppliers] @@ -202,20 +205,14 @@ def list( if input_data_types is not None: if isinstance(input_data_types, DataType) is True: input_data_types = [input_data_types] - payload["inputDataTypes"] = [ - data_type.value for data_type in input_data_types - ] + payload["inputDataTypes"] = [data_type.value for data_type in input_data_types] if output_data_types is not None: if isinstance(output_data_types, DataType) is True: output_data_types = [output_data_types] - payload["inputDataTypes"] = [ - data_type.value for data_type in output_data_types - ] + payload["inputDataTypes"] = [data_type.value for data_type in output_data_types] - logging.info( - f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}") try: r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() @@ -230,9 +227,7 @@ def list( results = resp["items"] page_total = resp["pageTotal"] total = resp["total"] - logging.info( - f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}" - ) + logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") for pipeline in results: pipelines.append(build_from_response(pipeline)) return { @@ -299,9 +294,7 @@ def create( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload payload = { "name": name, @@ -314,60 +307,10 @@ def create( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for POST Create Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for POST Create Pipeline - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() return Pipeline(response["id"], name, api_key) except Exception as e: - raise Exception(e) - - @classmethod - def _fetch_pipeline_by_id( - cls, pipeline_id: Text, api_key: Optional[Text] = None - ) -> Pipeline: - """Fetch a Pipeline by ID from the backend (no cache)""" - - resp = None - try: - url = urljoin(cls.backend_url, f"sdk/pipelines/{pipeline_id}") - if api_key is not None: - headers = { - "Authorization": f"Token {api_key}", - "Content-Type": "application/json", - } - else: - headers = { - "Authorization": f"Token {config.TEAM_API_KEY}", - "Content-Type": "application/json", - } - - logging.info(f"Start service for GET Pipeline - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - - except Exception as e: - logging.exception(e) - status_code = 400 - if resp is not None and "statusCode" in resp: - status_code = resp["statusCode"] - message = resp["message"] - message = f"Pipeline Creation: Status {status_code} - {message}" - else: - message = f"Pipeline Creation: Unspecified Error {e}" - logging.error(message) - raise Exception(f"Status {status_code}: {message}") - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - pipeline = build_from_response(resp, load_architecture=True) - logging.info(f"Pipeline {pipeline_id} retrieved successfully.") - return pipeline - - else: - error_message = f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" - logging.error(error_message) - raise Exception(error_message) + raise Exception(e) \ No newline at end of file diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8a68a3c3..200383bd 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -125,17 +125,13 @@ def _validate(self) -> None: except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") - assert ( - llm.function == Function.TEXT_GENERATION - ), "Large Language Model must be a text generation model." + assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." for tool in self.tools: if isinstance(tool, Tool): tool.validate() elif isinstance(tool, Model): - assert not isinstance( - tool, Agent - ), "Agent cannot contain another Agent." + assert not isinstance(tool, Agent), "Agent cannot contain another Agent." def validate(self, raise_exception: bool = False) -> bool: """Validate the Agent.""" @@ -148,9 +144,7 @@ def validate(self, raise_exception: bool = False) -> bool: raise e else: logging.warning(f"Agent Validation Error: {e}") - logging.warning( - "You won't be able to run the Agent until the issues are handled manually." - ) + logging.warning("You won't be able to run the Agent until the issues are handled manually.") return self.is_valid def run( @@ -207,10 +201,8 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll( - poll_url, name=name, timeout=timeout, wait_time=wait_time - ) - result_data = result.data + result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, completed=True, @@ -272,18 +264,12 @@ def run_async( from aixplain.factories.file_factory import FileFactory if not self.is_valid: - raise Exception( - "Agent is not valid. Please validate the agent before running." - ) + raise Exception("Agent is not valid. Please validate the agent before running.") - assert ( - data is not None or query is not None - ), "Either 'data' or 'query' must be provided." + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert ( - "query" in data and data["query"] is not None - ), "When providing a dictionary, 'query' must be provided." + assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." query = data.get("query") if session_id is None: session_id = data.get("session_id") @@ -296,9 +282,7 @@ def run_async( # process content inputs if content is not None: - assert ( - FileFactory.check_storage_type(query) == StorageType.TEXT - ), "When providing 'content', query must be text." + assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -307,16 +291,14 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert ( - "{{" + key + "}}" in query - ), f"Key '{key}' not found in query." + assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} # build query - input_data = process_variables(query, data, parameters, self.description) + input_data = process_variables(query, data, parameters, self.instructions) payload = { "id": self.id, @@ -324,16 +306,8 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": ( - parameters["max_tokens"] - if "max_tokens" in parameters - else max_tokens - ), - "maxIterations": ( - parameters["max_iterations"] - if "max_iterations" in parameters - else max_iterations - ), + "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), + "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), "outputFormat": output_format.value, }, } @@ -367,11 +341,7 @@ def to_dict(self) -> Dict: "assets": [tool.to_dict() for tool in self.tools], "description": self.description, "role": self.instructions, - "supplier": ( - self.supplier.value["code"] - if isinstance(self.supplier, Supplier) - else self.supplier - ), + "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, "llmId": self.llm_id, "status": self.status.value, @@ -409,8 +379,7 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -422,9 +391,7 @@ def update(self) -> None: payload = self.to_dict() - logging.debug( - f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}" - ) + logging.debug(f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}") resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) @@ -443,38 +410,10 @@ def save(self) -> None: self.update() def deploy(self) -> None: - assert ( - self.status == AssetStatus.DRAFT - ), "Agent must be in draft status to be deployed." + assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." self.status = AssetStatus.ONBOARDED self.update() def __repr__(self): - return f"Agent(id={self.id}, name={self.name}, function={self.function})" - - @classmethod - def from_dict(cls, data: dict) -> "Agent": - return cls( - id=data.get("id"), - name=data.get("name"), - description=data.get("description", ""), - instructions=data.get("role", ""), - tools=[], - llm_id=data.get("llmId"), - api_key=data.get("api_key", config.TEAM_API_KEY), - supplier=data.get("supplier", "aiXplain"), - version=data.get("version"), - cost=data.get("cost"), - status=( - AssetStatus(data["status"]) if data.get("status") else AssetStatus.DRAFT - ), - tasks=( - [AgentTask.from_dict(t) for t in data.get("tasks", [])] - if "tasks" in data - else [] - ), - function=( - Function(data["function"]) if data.get("function") is not None else None - ), - ) + return f"Agent(id={self.id}, name={self.name}, function={self.function})" \ No newline at end of file diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 18ba40ca..e32e9aa3 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -114,18 +114,14 @@ def __polling( while not response_body["completed"] and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) - logging.debug( - f"Polling for Pipeline: Status of polling for {name} : {response_body}" - ) + logging.debug(f"Polling for Pipeline: Status of polling for {name} : {response_body}") end = time.time() if not response_body["completed"]: time.sleep(wait_time) if wait_time < 60: wait_time *= 1.1 except Exception: - logging.error( - f"Polling for Pipeline '{self.id}': polling for {name} ({poll_url}): Continue" - ) + logging.error(f"Polling for Pipeline '{self.id}': polling for {name} ({poll_url}): Continue") break if response_body["status"] == ResponseStatus.SUCCESS: try: @@ -143,10 +139,7 @@ def __polling( return response_body def poll( - self, - poll_url: Text, - name: Text = "pipeline_process", - response_version: Text = "v2", + self, poll_url: Text, name: Text = "pipeline_process", response_version: Text = "v2" ) -> Union[Dict, PipelineResponse]: """Poll the platform to check whether an asynchronous call is done. @@ -170,9 +163,7 @@ def poll( resp["data"] = json.loads(resp["data"])["response"] except Exception: resp = r.json() - logging.info( - f"Single Poll for Pipeline '{self.id}' - Status of polling for {name} ({poll_url}): {resp}" - ) + logging.info(f"Single Poll for Pipeline '{self.id}' - Status of polling for {name} ({poll_url}): {resp}") if response_version == "v1": return resp status = ResponseStatus(resp.pop("status", "failed")) @@ -205,9 +196,7 @@ def run( ) -> Union[Dict, PipelineResponse]: start = time.time() try: - response = self.run_async( - data, data_asset=data_asset, name=name, version=version, **kwargs - ) + response = self.run_async(data, data_asset=data_asset, name=name, version=version, **kwargs) if response["status"] == ResponseStatus.FAILED: end = time.time() if response_version == "v1": @@ -224,16 +213,10 @@ def run( **kwargs, ) poll_url = response["url"] - polling_response = self.__polling( - poll_url, name=name, timeout=timeout, wait_time=wait_time - ) + polling_response = self.__polling(poll_url, name=name, timeout=timeout, wait_time=wait_time) end = time.time() - if isinstance(polling_response, dict): - status = ResponseStatus(polling_response.get("status", "failed")) - completed = polling_response.get("completed", False) - else: - status = polling_response.status - completed = getattr(polling_response, "completed", False) + status = ResponseStatus(polling_response["status"]) + completed = polling_response["completed"] if response_version == "v1": polling_response["elapsed_time"] = end - start return polling_response @@ -303,10 +286,7 @@ def __prepare_payload( try: payload = json.loads(data) if isinstance(payload, dict) is False: - if ( - isinstance(payload, int) is True - or isinstance(payload, float) is True - ): + if isinstance(payload, int) is True or isinstance(payload, float) is True: payload = str(payload) payload = {"data": payload} except Exception: @@ -344,9 +324,7 @@ def __prepare_payload( asset_payload["dataAsset"]["dataset_id"] = dasset.id source_data_list = [ - dfield - for dfield in dasset.source_data - if dasset.source_data[dfield].id == data[node_label] + dfield for dfield in dasset.source_data if dasset.source_data[dfield].id == data[node_label] ] if len(source_data_list) > 0: @@ -420,9 +398,7 @@ def run_async( try: if 200 <= r.status_code < 300: resp = r.json() - logging.info( - f"Result of request for {name} - {r.status_code} - {resp}" - ) + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") if response_version == "v1": return resp res = PipelineResponse( @@ -446,11 +422,11 @@ def run_async( error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." else: status_code = str(r.status_code) - error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + error = ( + f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + ) - logging.error( - f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}" - ) + logging.error(f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}") if response_version == "v1": return { "status": "failed", @@ -503,8 +479,7 @@ def update( stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -519,9 +494,7 @@ def update( for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -539,9 +512,7 @@ def update( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("put", url, headers=headers, json=payload) response = r.json() logging.info(f"Pipeline {response['id']} Updated.") @@ -592,15 +563,11 @@ def save( ), "Pipeline Update Error: Make sure the pipeline to be saved is in a JSON file." with open(pipeline) as f: pipeline = json.load(f) - self.update( - pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key - ) + self.update(pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key) for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: - pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i][ - "functionType" - ].lower() + pipeline["nodes"][i]["functionType"] = pipeline["nodes"][i]["functionType"].lower() # prepare payload status = "draft" if save_as_asset is True: @@ -617,9 +584,7 @@ def save( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) + logging.info(f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() self.id = response["id"] @@ -629,59 +594,12 @@ def save( def deploy(self, api_key: Optional[Text] = None) -> None: """Deploy the Pipeline.""" - assert ( - self.status == "draft" - ), "Pipeline Deployment Error: Pipeline must be in draft status." - assert ( - self.status != "onboarded" - ), "Pipeline Deployment Error: Pipeline must be onboarded." + assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." + assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." pipeline = self.to_dict() - self.update( - pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name - ) + self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) self.status = AssetStatus.ONBOARDED def __repr__(self): - return f"Pipeline(id={self.id}, name={self.name})" - - @classmethod - def from_dict(cls, data: Dict) -> "Pipeline": - """ - Create a Pipeline instance from a dictionary. - - Args: - data (Dict): A dictionary containing pipeline attributes. - - Returns: - Pipeline: An instance of the Pipeline class. - """ - return cls( - id=data.get("id"), - name=data.get("name"), - api_key=data.get("api_key", config.TEAM_API_KEY), - url=data.get("url", config.BACKEND_URL), - supplier=data.get("supplier", "aiXplain"), - version=data.get("version", "1.0"), - status=data.get("status", AssetStatus.DRAFT), - **data.get("additional_info", {}), - ) - - def to_dict(self) -> Dict: - """ - Serialize the Pipeline object to a dictionary. - - Returns: - Dict: A dictionary representation of the Pipeline instance. - """ - logging.info("corect to dict") - return { - "id": self.id, - "name": self.name, - "api_key": self.api_key, - "url": self.url, - "supplier": self.supplier, - "version": self.version, - "status": self.status.name if hasattr(self.status, "name") else self.status, - "additional_info": self.additional_info, - } + return f"Pipeline(id={self.id}, name={self.name})" \ No newline at end of file From 317a51ddb12a29e2da51fadd5b37416aa8323df8 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 9 May 2025 13:52:16 +0300 Subject: [PATCH 12/13] fixed unit test issue --- aixplain/modules/agent/tool/model_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 9b073a84..8a0f9dc2 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -174,7 +174,7 @@ def validate(self) -> None: self.description = self.model.description elif self.function is not None: try: - self.description = FunctionInputOutput[self.function.value]["spec"]["metaData"]["description"] + self.description = FunctionInputOutput[self.function.value]["spec"]["description"] except Exception: self.description = "" From e308a9515144b5c0a7a48ba13c3a6a7356abb1b9 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 12 May 2025 22:25:25 +0300 Subject: [PATCH 13/13] fixed relative import --- aixplain/enums/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index 8c25111b..462f2e38 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -25,7 +25,7 @@ from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from ..utils.asset_cache import AssetCache, CACHE_FOLDER +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter import os