Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-1551 ai xplain sdk caching onboarded models pipelines and agents #389

Open
wants to merge 6 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions aixplain/enums/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,19 @@
from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER
from typing import Tuple, Dict
from aixplain.base.parameters import BaseParameters, Parameter
import os

CACHE_FILE = f"{CACHE_FOLDER}/functions.json"
LOCK_FILE = f"{CACHE_FILE}.lock"


def load_functions():
api_key = config.TEAM_API_KEY
backend_url = config.BACKEND_URL

resp = load_from_cache(CACHE_FILE)
os.makedirs(CACHE_FOLDER, exist_ok=True)

resp = load_from_cache(CACHE_FILE, LOCK_FILE)
if resp is None:
url = urljoin(backend_url, "sdk/functions")

Expand All @@ -47,7 +51,7 @@ def load_functions():
f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)'
)
resp = r.json()
save_to_cache(CACHE_FILE, resp)
save_to_cache(CACHE_FILE, resp, LOCK_FILE)

class Function(str, Enum):
def __new__(cls, value):
Expand All @@ -63,6 +67,8 @@ def get_input_output_params(self) -> Tuple[Dict, Dict]:
Tuple[Dict, Dict]: A tuple containing (input_params, output_params)
"""
function_io = FunctionInputOutput.get(self.value, None)
if function_io is None:
return {}, {}
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}
return input_params, output_params
Expand Down
5 changes: 3 additions & 2 deletions aixplain/enums/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER

CACHE_FILE = f"{CACHE_FOLDER}/languages.json"
LOCK_FILE = f"{CACHE_FILE}.lock"


def load_languages():
resp = load_from_cache(CACHE_FILE)
resp = load_from_cache(CACHE_FILE, LOCK_FILE)
if resp is None:
api_key = config.TEAM_API_KEY
backend_url = config.BACKEND_URL
Expand All @@ -45,7 +46,7 @@ def load_languages():
f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)'
)
resp = r.json()
save_to_cache(CACHE_FILE, resp)
save_to_cache(CACHE_FILE, resp, LOCK_FILE)

languages = {}
for w in resp:
Expand Down
6 changes: 4 additions & 2 deletions aixplain/enums/license.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER

CACHE_FILE = f"{CACHE_FOLDER}/licenses.json"
LOCK_FILE = f"{CACHE_FILE}.lock"


def load_licenses():
resp = load_from_cache(CACHE_FILE)
resp = load_from_cache(CACHE_FILE, LOCK_FILE)

try:
if resp is None:
Expand All @@ -48,7 +49,8 @@ def load_licenses():
f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)'
)
resp = r.json()
save_to_cache(CACHE_FILE, resp)
save_to_cache(CACHE_FILE, resp, LOCK_FILE)

licenses = {"_".join(w["name"].split()): w["id"] for w in resp}
return Enum("License", licenses, type=str)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion aixplain/factories/team_agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def create(
from aixplain.modules.agent import Agent

assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class"

mentalist_and_inspector_llm_id = None
if use_inspector or use_mentalist_and_inspector:
mentalist_and_inspector_llm_id = llm_id
Expand Down
17 changes: 17 additions & 0 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.enums import ResponseStatus
from aixplain.modules.agent.utils import process_variables
from aixplain.modules.agent.cache_agents import load_agents
from typing import Dict, List, Text, Optional, Union
from urllib.parse import urljoin

Expand Down Expand Up @@ -77,6 +78,22 @@ def __init__(
tasks: List[AgentTask] = [],
**additional_info,
) -> None:
AgentCache, AgentDetails = load_agents(cache_expiry=86400)

if id in AgentDetails:
cached_agent = AgentDetails[id]
name = cached_agent["name"]
description = cached_agent["description"]
tools = cached_agent["tools"]
llm_id = cached_agent["llm_id"]
api_key = cached_agent["api_key"]
supplier = cached_agent["supplier"]
version = cached_agent["version"]
cost = cached_agent["cost"]
status = cached_agent["status"]
tasks = cached_agent["tasks"]
additional_info = cached_agent["additional_info"]

"""Create an Agent with the necessary information.

Args:
Expand Down
101 changes: 101 additions & 0 deletions aixplain/modules/agent/cache_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
import json
import logging
from datetime import datetime
from enum import Enum
from urllib.parse import urljoin
from typing import Dict, Optional, List, Tuple, Union, Text
from aixplain.utils import config
from aixplain.utils.request_utils import _request_with_retry
from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER
from aixplain.enums import Supplier
from aixplain.modules.agent.tool import Tool

# JSON file under CACHE_FOLDER that holds the cached agent payload.
AGENT_CACHE_FILE = f"{CACHE_FOLDER}/agents.json"
# Lock file path passed to the cache helpers together with AGENT_CACHE_FILE.
LOCK_FILE = f"{AGENT_CACHE_FILE}.lock"


def load_agents(cache_expiry: Optional[int] = None) -> Tuple[Enum, Dict]:
    """
    Return the agent enum and details, reading the cache first and the backend second.

    A cached payload, when present, is parsed and returned as is. Otherwise the
    agents are requested from the "sdk/agents" endpoint, reduced to the entries
    whose status is "onboarded" (case-insensitive), written to the cache and parsed.

    Args:
        cache_expiry (int, optional): Expiry time in seconds; falls back to
            86400 (24 hours) when omitted.
            NOTE(review): the resolved value is not used anywhere else in this
            function - confirm how expiry is meant to be enforced.

    Returns:
        Tuple[Enum, Dict]: (Enum of agent IDs, Dictionary with agent details).
        An empty enum and an empty dictionary are returned when the backend
        request fails.
    """
    cache_expiry = 86400 if cache_expiry is None else cache_expiry

    os.makedirs(CACHE_FOLDER, exist_ok=True)

    cached_payload = load_from_cache(AGENT_CACHE_FILE, LOCK_FILE)
    if cached_payload is not None:
        return parse_agents(cached_payload)

    request_headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
    endpoint = urljoin(config.BACKEND_URL, "sdk/agents")

    try:
        reply = _request_with_retry("get", endpoint, headers=request_headers)
        reply.raise_for_status()
        fetched = reply.json()
    except Exception as e:
        # Best effort: a failed fetch is logged and yields empty results instead of raising.
        logging.error(f"Failed to fetch agents from API: {e}")
        return Enum("Agent", {}), {}

    # Keep only the onboarded entries; entries without a status are dropped.
    onboarded = []
    for entry in fetched:
        if entry.get("status", "").lower() == "onboarded":
            onboarded.append(entry)

    save_to_cache(AGENT_CACHE_FILE, {"items": onboarded}, LOCK_FILE)

    return parse_agents({"items": onboarded})


def parse_agents(agents_data: Dict) -> Tuple[Enum, Dict]:
    """
    Convert agent data into an Enum and dictionary format for easy use.

    Args:
        agents_data (Dict): Payload holding the agent records under the "items"
            key. A payload that is empty, has no "items" key, or has a null
            "items" value is treated as containing no agents.

    Returns:
        Tuple[Enum, Dict]:
            - agents_enum: Enum whose member names are the upper-cased agent IDs
              with "-" replaced by "_", and whose values are the raw agent IDs.
            - agents_details: Dictionary keyed by agent ID containing all agent
              parameters.
    """
    # Tolerate a missing or null "items" entry (e.g. a malformed cache file)
    # instead of raising, since this runs when the module is imported.
    items = (agents_data or {}).get("items") or []
    if not items:
        logging.warning("No onboarded agents found.")
        return Enum("Agent", {}), {}

    agents_enum = Enum(
        "Agent",
        {a["id"].upper().replace("-", "_"): a["id"] for a in items},
        type=str,
    )

    agents_details = {
        agent["id"]: {
            "id": agent["id"],
            "name": agent.get("name", ""),
            "description": agent.get("description", ""),
            "role": agent.get("role", ""),
            "tools": [Tool(t) if isinstance(t, dict) else t for t in agent.get("tools", [])],
            "llm_id": agent.get("llm_id", "6646261c6eb563165658bbb1"),
            "supplier": agent.get("supplier", "aiXplain"),
            "version": agent.get("version", "1.0"),
            "status": agent.get("status", "onboarded"),
            "created_at": agent.get("created_at", ""),
            "tasks": agent.get("tasks", []),
            # The raw record is unpacked last, so every key it carries replaces
            # the entry built above; the entries above act as defaults only.
            # NOTE(review): this also replaces the converted "tools" list with
            # the raw one whenever the record has a "tools" key - confirm intent.
            **agent,
        }
        for agent in items
    }

    return agents_enum, agents_details


# Build the agent enum and the details mapping once, when this module is imported.
Agent, AgentDetails = load_agents()
16 changes: 16 additions & 0 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from aixplain.modules.model.response import ModelResponse
from aixplain.enums.response_status import ResponseStatus
from aixplain.modules.model.model_parameters import ModelParameters
from aixplain.modules.model.cache_models import load_models


class Model(Asset):
Expand Down Expand Up @@ -91,6 +92,21 @@ def __init__(
model_params (Dict, optional): parameters for the function.
**additional_info: Any additional Model info to be saved
"""
ModelCache, ModelDetails = load_models(cache_expiry=86400)

if id in ModelDetails:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we need the entire data to populate a Model class instance, why wouldn't we get the related Model class instance directly from the cache?

For example, the sample code below should be used in the ModelFactory:

ModelCache = AixplainCache(Model)
ModelCache.load_assets() # this could run once internally (populate on demand)
model = ModelCache.get(id)
if not model:
    model = Model(id=..., foo=bar)


cached_model = ModelDetails[id]
input_params = cached_model["input"]
api_key = cached_model["spec"].get("api_key", api_key)
additional_info = cached_model["spec"].get("additional_info", {})
function = cached_model["spec"].get("function", function)
is_subscribed = cached_model["spec"].get("is_subscribed", is_subscribed)
created_at = cached_model["spec"].get("created_at", created_at)
model_params = cached_model["spec"].get("model_params", model_params)
output_params = cached_model["output"]
description = cached_model["spec"].get("description", description)

super().__init__(id, name, description, supplier, version, cost=cost)
self.api_key = api_key
self.additional_info = additional_info
Expand Down
88 changes: 88 additions & 0 deletions aixplain/modules/model/cache_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import json
import time
import logging
from datetime import datetime
from enum import Enum
from urllib.parse import urljoin
from typing import Dict, Optional, Union, Text
from aixplain.utils import config
from aixplain.utils.request_utils import _request_with_retry
from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER
from aixplain.enums import Supplier, Function

# JSON file under CACHE_FOLDER that holds the cached model payload.
CACHE_FILE = f"{CACHE_FOLDER}/models.json"
# Lock file path passed to the cache helpers together with CACHE_FILE.
LOCK_FILE = f"{CACHE_FILE}.lock"


def load_models(cache_expiry: Optional[int] = None):
    """
    Load models from cache or fetch from backend if not cached.

    Only models whose status is "onboarded" (case-insensitive) are cached and
    returned. Records with a missing or null status are skipped.

    Args:
        cache_expiry (int, optional): Expiry time in seconds.
            NOTE(review): this value is not used by the function - confirm how
            expiry is meant to be enforced.

    Returns:
        The (models_enum, models_details) pair produced by ``parse_models``.

    Raises:
        Exception: If the backend answers with a non-2xx status code.
    """
    # Ensure the cache folder exists before the cache and lock files are accessed.
    os.makedirs(CACHE_FOLDER, exist_ok=True)

    cached_data = load_from_cache(CACHE_FILE, LOCK_FILE)
    if cached_data is not None:
        return parse_models(cached_data)

    api_key = config.TEAM_API_KEY
    backend_url = config.BACKEND_URL

    url = urljoin(backend_url, "sdk/models")
    headers = {"x-api-key": api_key, "Content-Type": "application/json"}

    response = _request_with_retry("get", url, headers=headers)
    if not 200 <= response.status_code < 300:
        raise Exception(f"Models could not be loaded, API key '{api_key}' might be invalid.")

    models_data = response.json()

    # A record without a usable status cannot be "onboarded"; skip it rather
    # than letting one malformed entry abort the whole load.
    onboarded_models = [
        model
        for model in models_data["items"]
        if str(model.get("status") or "").lower() == "onboarded"
    ]
    save_to_cache(CACHE_FILE, {"items": onboarded_models}, LOCK_FILE)

    return parse_models({"items": onboarded_models})


def parse_models(models_data):
    """
    Convert model data into an Enum and dictionary format for easy use.

    Args:
        models_data: Payload holding the model records under the "items" key.
            A payload that is empty, has no "items" key, or has a null "items"
            value is treated as containing no models.

    Returns:
        - models_enum: Enum whose member names are the upper-cased model IDs
          with "-" replaced by "_", and whose values are the raw model IDs.
        - models_details: Dictionary keyed by model ID containing all model
          parameters.
    """
    # Tolerate a missing or null "items" entry (e.g. a malformed cache file)
    # instead of raising, since this runs when the module is imported.
    items = (models_data or {}).get("items") or []
    if not items:
        logging.warning("No onboarded models found.")
        return Enum("Model", {}), {}

    models_enum = Enum("Model", {m["id"].upper().replace("-", "_"): m["id"] for m in items}, type=str)

    models_details = {
        model["id"]: {
            "id": model["id"],
            "name": model.get("name", ""),
            "description": model.get("description", ""),
            "api_key": model.get("api_key", config.TEAM_API_KEY),
            "supplier": model.get("supplier", "aiXplain"),
            "version": model.get("version"),
            "function": model.get("function"),
            "is_subscribed": model.get("is_subscribed", False),
            "cost": model.get("cost"),
            "created_at": model.get("created_at"),
            "input_params": model.get("input_params"),
            "output_params": model.get("output_params"),
            "model_params": model.get("model_params"),
            # The raw record is unpacked last, so every key it carries replaces
            # the entry built above; the entries above act as defaults only.
            **model,
        }
        for model in items
    }

    return models_enum, models_details


# Build the model enum and the details mapping once, when this module is imported.
Model, ModelDetails = load_models()
Loading