From 789d2dcb75e2b3d027a702d238982cf90c2b2ba8 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Mon, 25 Nov 2024 17:54:15 -0300 Subject: [PATCH 1/6] Onboard Utilities --- aixplain/enums/data_type.py | 1 + .../__init__.py} | 173 +++++------------ aixplain/factories/model_factory/utils.py | 142 ++++++++++++++ aixplain/modules/model/utility_model.py | 179 ++++++++++++++++++ tests/unit/model_test.py | 4 +- 5 files changed, 368 insertions(+), 131 deletions(-) rename aixplain/factories/{model_factory.py => model_factory/__init__.py} (73%) create mode 100644 aixplain/factories/model_factory/utils.py create mode 100644 aixplain/modules/model/utility_model.py diff --git a/aixplain/enums/data_type.py b/aixplain/enums/data_type.py index 11432bcf..dcae0422 100644 --- a/aixplain/enums/data_type.py +++ b/aixplain/enums/data_type.py @@ -35,6 +35,7 @@ class DataType(str, Enum): VIDEO = "video" EMBEDDING = "embedding" NUMBER = "number" + BOOLEAN = "boolean" def __str__(self): return self._value_ diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory/__init__.py similarity index 73% rename from aixplain/factories/model_factory.py rename to aixplain/factories/model_factory/__init__.py index b6588023..13db1fb4 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory/__init__.py @@ -24,14 +24,11 @@ import json import logging from aixplain.modules.model import Model -from aixplain.modules.model.llm_model import LLM +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from warnings import warn -from aixplain.enums.function import FunctionInputOutput -from datetime import datetime class ModelFactory: @@ -45,53 +42,48 @@ class ModelFactory: backend_url = config.BACKEND_URL @classmethod - def _create_model_from_response(cls, response: Dict) -> Model: - """Converts response Json to 'Model' object + def create_utility_model(cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text) -> UtilityModel: + """Create a utility model Args: - response (Dict): Json from API + name (Text): name of the model + description (Text): description of the model + inputs (List[UtilityModelInput]): inputs of the model + code (Text): code of the model Returns: - Model: Coverted 'Model' object + UtilityModel: created utility model """ - if "api_key" not in response: - response["api_key"] = config.TEAM_API_KEY - - parameters = {} - if "params" in response: - for param in response["params"]: - if "language" in param["name"]: - parameters[param["name"]] = [w["value"] for w in param["values"]] - - function = Function(response["function"]["id"]) - ModelClass = Model - if function == Function.TEXT_GENERATION: - ModelClass = LLM - - created_at = None - if "createdAt" in response and response["createdAt"]: - created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) - function_id = response["function"]["id"] - function = Function(function_id) - function_io = FunctionInputOutput.get(function_id, None) - input_params = {param["code"]: param for param in function_io["spec"]["params"]} - output_params = {param["code"]: param for param in function_io["spec"]["output"]} - - return ModelClass( - response["id"], - response["name"], - description=response.get("description", ""), - supplier=response["supplier"], - api_key=response["api_key"], - cost=response["pricing"], - function=function, - created_at=created_at, - parameters=parameters, - input_params=input_params, - output_params=output_params, - is_subscribed=True if "subscription" in response else False, - version=response["version"]["id"], + utility_model = UtilityModel( + id="", + name=name, + description=description, + inputs=inputs, + code=code, + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, ) + payload = utility_model.to_dict() + url = urljoin(cls.backend_url, "sdk/utilities") + headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + try: + logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + except Exception as e: + logging.error(f"Error creating utility model: {e}") + raise e + + if 200 <= r.status_code < 300: + utility_model.id = resp["id"] + logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") + return utility_model + else: + error_message = ( + f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) @classmethod def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: @@ -128,7 +120,9 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: resp["api_key"] = config.TEAM_API_KEY if api_key is not None: resp["api_key"] = api_key - model = cls._create_model_from_response(resp) + from aixplain.factories.model_factory.utils import create_model_from_response + + model = create_model_from_response(resp) logging.info(f"Model Creation: Model {model_id} instantiated.") return model else: @@ -136,89 +130,6 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: logging.error(error_message) raise Exception(error_message) - @classmethod - def create_asset_from_id(cls, model_id: Text) -> Model: - warn( - 'This method will be deprecated in the next versions of the SDK. Use "get" instead.', - DeprecationWarning, - stacklevel=2, - ) - return cls.get(model_id) - - @classmethod - def _get_assets_from_page( - cls, - query, - page_number: int, - page_size: int, - function: Function, - suppliers: Union[Supplier, List[Supplier]], - source_languages: Union[Language, List[Language]], - target_languages: Union[Language, List[Language]], - is_finetunable: bool = None, - ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, - sort_by: Optional[SortBy] = None, - sort_order: SortOrder = SortOrder.ASCENDING, - ) -> List[Model]: - try: - url = urljoin(cls.backend_url, "sdk/models/paginate") - filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size} - if is_finetunable is not None: - filter_params["isFineTunable"] = is_finetunable - if function is not None: - filter_params["functions"] = [function.value] - if suppliers is not None: - if isinstance(suppliers, Supplier) is True: - suppliers = [suppliers] - filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers] - if ownership is not None: - if isinstance(ownership, OwnershipType) is True: - ownership = [ownership] - filter_params["ownership"] = [ownership_.value for ownership_ in ownership] - - lang_filter_params = [] - if source_languages is not None: - if isinstance(source_languages, Language): - source_languages = [source_languages] - if function == Function.TRANSLATION: - lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]}) - else: - lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]}) - if source_languages[0].value["dialect"] != "": - lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]}) - if target_languages is not None: - if isinstance(target_languages, Language): - target_languages = [target_languages] - if function == Function.TRANSLATION: - code = "targetlanguage" - lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]}) - if sort_by is not None: - filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}] - if len(lang_filter_params) != 0: - filter_params["ioFilter"] = lang_filter_params - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - - logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") - r = _request_with_retry("post", url, headers=headers, json=filter_params) - resp = r.json() - - except Exception as e: - error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" - logging.error(error_message, exc_info=True) - return [] - if 200 <= r.status_code < 300: - logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") - all_models = resp["items"] - model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] - return model_list, resp["total"] - else: - error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" - logging.error(error_message) - raise Exception(error_message) - @classmethod def list( cls, @@ -249,7 +160,9 @@ def list( Returns: List[Model]: List of models based on given filters """ - models, total = cls._get_assets_from_page( + from aixplain.factories.model_factory.utils import get_assets_from_page + + models, total = get_assets_from_page( query, page_number, page_size, diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py new file mode 100644 index 00000000..01423795 --- /dev/null +++ b/aixplain/factories/model_factory/utils.py @@ -0,0 +1,142 @@ +import json +import logging +from aixplain.modules.model import Model +from aixplain.modules.model.llm_model import LLM +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput +from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums.function import FunctionInputOutput +from datetime import datetime +from typing import Dict, Union, List, Optional, Tuple +from urllib.parse import urljoin + + +def create_model_from_response(response: Dict) -> Model: + """Converts response Json to 'Model' object + + Args: + response (Dict): Json from API + + Returns: + Model: Coverted 'Model' object + """ + if "api_key" not in response: + response["api_key"] = config.TEAM_API_KEY + + parameters = {} + if "params" in response: + for param in response["params"]: + if "language" in param["name"]: + parameters[param["name"]] = [w["value"] for w in param["values"]] + + function = Function(response["function"]["id"]) + inputs = [] + ModelClass = Model + if function == Function.TEXT_GENERATION: + ModelClass = LLM + elif function == Function.UTILITIES: + ModelClass = UtilityModel + inputs = [ + UtilityModelInput(name=param["name"], description=param.get("description", ""), type=DataType(param["dataType"])) + for param in response["params"] + ] + + created_at = None + if "createdAt" in response and response["createdAt"]: + created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) + function_id = response["function"]["id"] + function = Function(function_id) + function_io = FunctionInputOutput.get(function_id, None) + input_params = {param["code"]: param for param in function_io["spec"]["params"]} + output_params = {param["code"]: param for param in function_io["spec"]["output"]} + + return ModelClass( + response["id"], + response["name"], + description=response.get("description", ""), + code=response.get("code", ""), + supplier=response["supplier"], + api_key=response["api_key"], + cost=response["pricing"], + function=function, + created_at=created_at, + parameters=parameters, + input_params=input_params, + output_params=output_params, + is_subscribed=True if "subscription" in response else False, + version=response["version"]["id"], + inputs=inputs, + ) + + +def get_assets_from_page( + query, + page_number: int, + page_size: int, + function: Function, + suppliers: Union[Supplier, List[Supplier]], + source_languages: Union[Language, List[Language]], + target_languages: Union[Language, List[Language]], + is_finetunable: bool = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, +) -> List[Model]: + try: + url = urljoin(config.BACKEND_URL, "sdk/models/paginate") + filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size} + if is_finetunable is not None: + filter_params["isFineTunable"] = is_finetunable + if function is not None: + filter_params["functions"] = [function.value] + if suppliers is not None: + if isinstance(suppliers, Supplier) is True: + suppliers = [suppliers] + filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers] + if ownership is not None: + if isinstance(ownership, OwnershipType) is True: + ownership = [ownership] + filter_params["ownership"] = [ownership_.value for ownership_ in ownership] + + lang_filter_params = [] + if source_languages is not None: + if isinstance(source_languages, Language): + source_languages = [source_languages] + if function == Function.TRANSLATION: + lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]}) + else: + lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]}) + if source_languages[0].value["dialect"] != "": + lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]}) + if target_languages is not None: + if isinstance(target_languages, Language): + target_languages = [target_languages] + if function == Function.TRANSLATION: + code = "targetlanguage" + lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]}) + if sort_by is not None: + filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}] + if len(lang_filter_params) != 0: + filter_params["ioFilter"] = lang_filter_params + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") + r = _request_with_retry("post", url, headers=headers, json=filter_params) + resp = r.json() + + except Exception as e: + error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" + logging.error(error_message, exc_info=True) + return [] + if 200 <= r.status_code < 300: + logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") + all_models = resp["items"] + from aixplain.factories.model_factory.utils import create_model_from_response + + model_list = [create_model_from_response(model_info_json) for model_info_json in all_models] + return model_list, resp["total"] + else: + error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py new file mode 100644 index 00000000..6835b56a --- /dev/null +++ b/aixplain/modules/model/utility_model.py @@ -0,0 +1,179 @@ +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 +Description: + Utility Model Class +""" +import logging +import os +import validators +from aixplain.enums import Function, Supplier, DataType +from aixplain.modules.model import Model +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry +from dataclasses import dataclass +from typing import Union, Optional, List, Text, Dict +from urllib.parse import urljoin + + +@dataclass +class UtilityModelInput: + name: Text + description: Text + type: DataType = DataType.TEXT + + def __post_init__(self): + self.validate_type() + + def validate_type(self): + if self.type not in [DataType.TEXT, DataType.BOOLEAN, DataType.NUMBER]: + raise ValueError("Utility Model Input type must be TEXT, BOOLEAN or NUMBER") + + def to_dict(self): + return {"name": self.name, "description": self.description, "type": self.type.value} + + +class UtilityModel(Model): + """Ready-to-use Utility Model. + + Attributes: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. Defaults to None. + url (Text, optional): endpoint of the model. Defaults to config.MODELS_RUN_URL. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Text, optional): model AI function. Defaults to None. + url (str): URL to run the model. + backend_url (str): URL of the backend. + pricing (Dict, optional): model price. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + + def __init__( + self, + id: Text, + name: Text, + description: Text, + code: Text, + inputs: List[UtilityModelInput], + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + **additional_info, + ) -> None: + """Utility Model Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text): description of the model. + code (Text): code of the model. + inputs (List[UtilityModelInput]): inputs of the model. + api_key (Text, optional): API key of the Model. Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + assert function == Function.UTILITIES, "Utility Model only supports 'utilities' function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + **additional_info, + ) + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + self.code = code + self.inputs = inputs + self.validate() + + def validate(self): + from aixplain.factories.file_factory import FileFactory + from uuid import uuid4 + + assert self.name and self.name.strip() != "", "Name is required" + assert self.description and self.description.strip() != "", "Description is required" + assert self.code and self.code.strip() != "", "Code is required" + assert self.inputs and len(self.inputs) > 0, "At least one input is required" + + self.code = FileFactory.to_link(self.code) + # store code in a temporary local path if it is not a valid URL or S3 path + if not validators.url(self.code) and not self.code.startswith("s3:"): + local_path = str(uuid4()) + with open(local_path, "w") as f: + f.write(self.code) + self.code = FileFactory.upload(local_path=local_path, is_temp=True) + os.remove(local_path) + + def to_dict(self): + return { + "name": self.name, + "description": self.description, + "inputs": [input.to_dict() for input in self.inputs], + "code": self.code, + "function": self.function.value, + } + + def update(self): + self.validate() + url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") + headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} + payload = self.to_dict() + try: + logging.info(f"Start service for PUT Utility Model - {url} - {headers} - {payload}") + r = _request_with_retry("put", url, headers=headers, json=payload) + response = r.json() + except Exception as e: + message = f"Utility Model Update Error: {e}" + logging.error(message) + raise Exception(f"{message}") + + if not 200 <= r.status_code < 300: + message = f"Utility Model Update Error: {response}" + logging.error(message) + raise Exception(f"{message}") + + def delete(self): + url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") + headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} + try: + logging.info(f"Start service for DELETE Utility Model - {url} - {headers}") + r = _request_with_retry("delete", url, headers=headers) + response = r.json() + except Exception: + message = "Utility Model Deletion Error: Make sure the utility model exists and you are the owner." + logging.error(message) + raise Exception(f"{message}") + + if r.status_code != 200: + message = f"Utility Model Deletion Error: {response}" + logging.error(message) + raise Exception(f"{message}") diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index a2463a8d..4431c135 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -164,6 +164,8 @@ def test_get_model_error_response(): def test_get_assets_from_page_error(): + from aixplain.factories.model_factory.utils import get_assets_from_page + with requests_mock.Mocker() as mock: query = "test-query" page_number = 0 @@ -175,7 +177,7 @@ def test_get_assets_from_page_error(): mock.post(url, headers=headers, json=error_response, status_code=500) with pytest.raises(Exception) as excinfo: - ModelFactory._get_assets_from_page( + get_assets_from_page( query=query, page_number=page_number, page_size=page_size, From 685db06a4f0e605857b140bc46e6673427fe9093 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 26 Nov 2024 15:50:48 -0300 Subject: [PATCH 2/6] Tests for utility models --- aixplain/factories/model_factory/__init__.py | 6 +- aixplain/modules/model/utility_model.py | 5 + .../model/run_utility_model_test.py | 33 +++++++ tests/unit/utility_test.py | 99 +++++++++++++++++++ 4 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 tests/functional/model/run_utility_model_test.py create mode 100644 tests/unit/utility_test.py diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 13db1fb4..75156426 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -42,7 +42,9 @@ class ModelFactory: backend_url = config.BACKEND_URL @classmethod - def create_utility_model(cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text) -> UtilityModel: + def create_utility_model( + cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text, output_description: Text + ) -> UtilityModel: """Create a utility model Args: @@ -50,6 +52,7 @@ def create_utility_model(cls, name: Text, description: Text, inputs: List[Utilit description (Text): description of the model inputs (List[UtilityModelInput]): inputs of the model code (Text): code of the model + output_description (Text): description of the output Returns: UtilityModel: created utility model @@ -62,6 +65,7 @@ def create_utility_model(cls, name: Text, description: Text, inputs: List[Utilit code=code, function=Function.UTILITIES, api_key=config.TEAM_API_KEY, + output_description=output_description, ) payload = utility_model.to_dict() url = urljoin(cls.backend_url, "sdk/utilities") diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 6835b56a..31bc6058 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -72,6 +72,7 @@ def __init__( description: Text, code: Text, inputs: List[UtilityModelInput], + output_description: Text, api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -88,6 +89,7 @@ def __init__( description (Text): description of the model. code (Text): code of the model. inputs (List[UtilityModelInput]): inputs of the model. + output_description (Text): description of the output api_key (Text, optional): API key of the Model. Defaults to None. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". @@ -113,6 +115,7 @@ def __init__( self.backend_url = config.BACKEND_URL self.code = code self.inputs = inputs + self.output_description = output_description self.validate() def validate(self): @@ -123,6 +126,7 @@ def validate(self): assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" assert self.inputs and len(self.inputs) > 0, "At least one input is required" + assert self.output_description and self.output_description.strip() != "", "Output description is required" self.code = FileFactory.to_link(self.code) # store code in a temporary local path if it is not a valid URL or S3 path @@ -140,6 +144,7 @@ def to_dict(self): "inputs": [input.to_dict() for input in self.inputs], "code": self.code, "function": self.function.value, + "outputDescription": self.output_description, } def update(self): diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py new file mode 100644 index 00000000..5887c4ca --- /dev/null +++ b/tests/functional/model/run_utility_model_test.py @@ -0,0 +1,33 @@ +from aixplain.factories import ModelFactory +from aixplain.modules.model.utility_model import UtilityModelInput +from aixplain.enums import DataType + + +def test_run_utility_model(): + inputs = [ + UtilityModelInput(name="inputA", description="input A is the only input", type=DataType.TEXT), + ] + + output_description = "An example is 'test'" + + utility_model = ModelFactory.create_utility_model( + name="test_script", + description="This is a test script", + inputs=inputs, + code="def main(inputA):\n\treturn inputA", + output_description=output_description, + ) + + assert utility_model.id is not None + + response = utility_model.run(data={"inputA": "test"}) + assert response.status == "SUCCESS" + assert response.data == "test" + + utility_model.code = "def main(inputA):\n\treturn 5" + utility_model.update() + response = utility_model.run(data={"inputA": "test"}) + assert response.status == "SUCCESS" + assert str(response.data) == "5" + + utility_model.delete() diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py new file mode 100644 index 00000000..c1b7b9e1 --- /dev/null +++ b/tests/unit/utility_test.py @@ -0,0 +1,99 @@ +import pytest +import requests_mock +from aixplain.factories.model_factory import ModelFactory +from urllib.parse import urljoin +from aixplain.utils import config +from aixplain.enums import DataType, Function +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput +from unittest.mock import patch + + +def test_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + output_description="output_description", + ) + assert utility_model.id == "123" + assert utility_model.name == "utility_model_test" + assert utility_model.description == "utility_model_test" + assert utility_model.code == "utility_model_test" + assert utility_model.inputs == [ + UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT) + ] + assert utility_model.output_description == "output_description" + + +def test_utility_model_with_invalid_name(): + with pytest.raises(Exception) as exc_info: + ModelFactory.create_utility_model( + name="", + description="utility_model_test", + code="utility_model_test", + inputs=[], + output_description="output_description", + ) + assert str(exc_info.value) == "Name is required" + + +def test_utility_model_with_invalid_inputs(): + with pytest.raises(Exception) as exc_info: + ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + inputs=[], + output_description="output_description", + ) + assert str(exc_info.value) == "At least one input is required" + + +def test_utility_model_to_dict(): + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + output_description="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + assert utility_model.to_dict() == { + "name": "utility_model_test", + "description": "utility_model_test", + "inputs": [{"name": "originCode", "description": "originCode", "type": "text"}], + "code": "utility_model_test", + "function": "utilities", + "outputDescription": "output_description", + } + + +def test_update_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + output_description="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.description = "updated_description" + utility_model.update() + + assert utility_model.id == "123" + assert utility_model.description == "updated_description" From d543a0eb86d979432c6739b54590fb6e763afac3 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 3 Dec 2024 18:57:55 -0300 Subject: [PATCH 3/6] Parse code function --- aixplain/factories/model_factory/__init__.py | 11 +- aixplain/modules/model/utility_model.py | 53 +++---- aixplain/modules/model/utils.py | 60 ++++++- tests/unit/utility_test.py | 156 +++++++++++-------- 4 files changed, 179 insertions(+), 101 deletions(-) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 8ec3183b..435213c4 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -20,7 +20,7 @@ Description: Model Factory Class """ -from typing import Dict, List, Optional, Text, Tuple, Union +from typing import Callable, Dict, List, Optional, Text, Tuple, Union import json import logging from aixplain.modules.model import Model @@ -42,7 +42,12 @@ class ModelFactory: @classmethod def create_utility_model( - cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text, output_description: Text + cls, + name: Text, + description: Text, + inputs: List[UtilityModelInput], + code: Union[Text, Callable], + output_description: Text, ) -> UtilityModel: """Create a utility model @@ -50,7 +55,7 @@ def create_utility_model( name (Text): name of the model description (Text): description of the model inputs (List[UtilityModelInput]): inputs of the model - code (Text): code of the model + code (Union[Text, Callable]): code of the model output_description (Text): description of the output Returns: diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 31bc6058..37c8b6f6 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -19,14 +19,13 @@ Utility Model Class """ import logging -import os -import validators from aixplain.enums import Function, Supplier, DataType from aixplain.modules.model import Model from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry +from aixplain.modules.model.utils import parse_code from dataclasses import dataclass -from typing import Union, Optional, List, Text, Dict +from typing import Callable, Union, Optional, List, Text, Dict from urllib.parse import urljoin @@ -53,15 +52,16 @@ class UtilityModel(Model): Attributes: id (Text): ID of the Model name (Text): Name of the Model - description (Text, optional): description of the model. Defaults to "". + code (Union[Text, Callable]): code of the model. + description (Text): description of the model. Defaults to "". + inputs (List[UtilityModelInput]): inputs of the model. Defaults to []. + output_description (Text): description of the output. Defaults to "". api_key (Text, optional): API key of the Model. Defaults to None. - url (Text, optional): endpoint of the model. Defaults to config.MODELS_RUN_URL. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". - function (Text, optional): model AI function. Defaults to None. - url (str): URL to run the model. - backend_url (str): URL of the backend. - pricing (Dict, optional): model price. Defaults to None. + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. **additional_info: Any additional Model info to be saved """ @@ -69,10 +69,10 @@ def __init__( self, id: Text, name: Text, - description: Text, - code: Text, - inputs: List[UtilityModelInput], - output_description: Text, + code: Union[Text, Callable], + description: Optional[Text] = None, + inputs: List[UtilityModelInput] = [], + output_description: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -86,10 +86,10 @@ def __init__( Args: id (Text): ID of the Model name (Text): Name of the Model - description (Text): description of the model. - code (Text): code of the model. - inputs (List[UtilityModelInput]): inputs of the model. - output_description (Text): description of the output + code (Union[Text, Callable]): code of the model. + description (Text): description of the model. Defaults to "". + inputs (List[UtilityModelInput]): inputs of the model. Defaults to []. + output_description (Text): description of the output. Defaults to "". api_key (Text, optional): API key of the Model. Defaults to None. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". @@ -119,24 +119,17 @@ def __init__( self.validate() def validate(self): - from aixplain.factories.file_factory import FileFactory - from uuid import uuid4 - + self.code, inputs, description = parse_code(self.code) + assert description is not None or self.description is not None, "Utility Model Error: Model description is required" + if self.description is None: + self.description = description + if len(self.inputs) == 0: + self.inputs = inputs assert self.name and self.name.strip() != "", "Name is required" assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" - assert self.inputs and len(self.inputs) > 0, "At least one input is required" assert self.output_description and self.output_description.strip() != "", "Output description is required" - self.code = FileFactory.to_link(self.code) - # store code in a temporary local path if it is not a valid URL or S3 path - if not validators.url(self.code) and not self.code.startswith("s3:"): - local_path = str(uuid4()) - with open(local_path, "w") as f: - f.write(self.code) - self.code = FileFactory.upload(local_path=local_path, is_temp=True) - os.remove(local_path) - def to_dict(self): return { "name": self.name, diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 13cc1f7c..f2762080 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -3,7 +3,7 @@ import json import logging from aixplain.utils.file_utils import _request_with_retry -from typing import Dict, Text, Union, Optional +from typing import Callable, Dict, List, Text, Tuple, Union, Optional def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): @@ -77,3 +77,61 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: response = {"status": "FAILED", "error_message": error, "completed": True} logging.error(f"Error in request: {r.status_code}: {error}") return response + + +def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: + import inspect + import os + import re + import requests + import validators + from aixplain.modules.model.utility_model import UtilityModelInput + from aixplain.factories.file_factory import FileFactory + from uuid import uuid4 + + inputs, description = [], "" + + if isinstance(code, Callable): + str_code = inspect.getsource(code) + description = code.__doc__ + elif os.path.exists(code): + with open(code, "r") as f: + str_code = f.read() + elif validators.url(code): + str_code = requests.get(code).text + else: + str_code = code + + # assert str_code has a main function + if "def main(" not in str_code: + raise Exception("Utility Model Error: Code must have a main function") + + f = re.findall(r"main\((.*?(?:\s*=\s*[^,)]+)?(?:\s*,\s*.*?(?:\s*=\s*[^,)]+)?)*)\)", str_code) + parameters = f[0].split(",") if len(f) > 0 else [] + + for input in parameters: + assert ( + len(input.split(":")) > 1 + ), "Utility Model Error: Input type is required. For instance def main(a: int, b: int) -> int:" + input_name, input_type = input.split(":") + input_name = input_name.strip() + input_type = input_type.split("=")[0].strip() + + if input_type in ["int", "float"]: + input_type = "number" + inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + elif input_type == "bool": + input_type = "boolean" + inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + elif input_type == "str": + input_type = "text" + inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + else: + raise Exception(f"Utility Model Error:Unsupported input type: {input_type}") + + local_path = str(uuid4()) + with open(local_path, "w") as f: + f.write(str_code) + code = FileFactory.upload(local_path=local_path, is_temp=True) + os.remove(local_path) + return code, inputs, description diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index c1b7b9e1..1988d0b1 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -12,88 +12,110 @@ def test_utility_model(): with requests_mock.Mocker() as mock: with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): - mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) - utility_model = ModelFactory.create_utility_model( - name="utility_model_test", - description="utility_model_test", - code="utility_model_test", - inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], - output_description="output_description", - ) - assert utility_model.id == "123" - assert utility_model.name == "utility_model_test" - assert utility_model.description == "utility_model_test" - assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [ - UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT) - ] - assert utility_model.output_description == "output_description" + with patch( + "aixplain.modules.model.utils.parse_code", + return_value=( + "def main(originCode: str)", + [UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + "utility_model_test", + ), + ): + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + output_description="output_description", + ) + assert utility_model.id == "123" + assert utility_model.name == "utility_model_test" + assert utility_model.description == "utility_model_test" + assert utility_model.code == "utility_model_test" + assert utility_model.inputs == [ + UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT) + ] + assert utility_model.output_description == "output_description" def test_utility_model_with_invalid_name(): - with pytest.raises(Exception) as exc_info: - ModelFactory.create_utility_model( - name="", - description="utility_model_test", - code="utility_model_test", - inputs=[], - output_description="output_description", - ) - assert str(exc_info.value) == "Name is required" - - -def test_utility_model_with_invalid_inputs(): - with pytest.raises(Exception) as exc_info: - ModelFactory.create_utility_model( - name="utility_model_test", - description="utility_model_test", - code="utility_model_test", - inputs=[], - output_description="output_description", - ) - assert str(exc_info.value) == "At least one input is required" + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + with patch( + "aixplain.modules.model.utils.parse_code", + return_value=( + "def main(originCode: str)", + [UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + "utility_model_test", + ), + ): + with pytest.raises(Exception) as exc_info: + ModelFactory.create_utility_model( + name="", + description="utility_model_test", + code="def main(originCode: str)", + inputs=[], + output_description="output_description", + ) + assert str(exc_info.value) == "Name is required" def test_utility_model_to_dict(): with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): - utility_model = UtilityModel( - id="123", - name="utility_model_test", - description="utility_model_test", - code="utility_model_test", - output_description="output_description", - inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], - function=Function.UTILITIES, - api_key=config.TEAM_API_KEY, - ) - assert utility_model.to_dict() == { - "name": "utility_model_test", - "description": "utility_model_test", - "inputs": [{"name": "originCode", "description": "originCode", "type": "text"}], - "code": "utility_model_test", - "function": "utilities", - "outputDescription": "output_description", - } - - -def test_update_utility_model(): - with requests_mock.Mocker() as mock: - with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): - with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): - mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + with patch( + "aixplain.modules.model.utils.parse_code", + return_value=( + "def main(originCode: str)", + [UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + "utility_model_test", + ), + ): utility_model = UtilityModel( id="123", name="utility_model_test", description="utility_model_test", - code="utility_model_test", + code="def main(originCode: str)", output_description="output_description", inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], function=Function.UTILITIES, api_key=config.TEAM_API_KEY, ) - utility_model.description = "updated_description" - utility_model.update() + assert utility_model.to_dict() == { + "name": "utility_model_test", + "description": "utility_model_test", + "inputs": [{"name": "originCode", "description": "originCode", "type": "text"}], + "code": "utility_model_test", + "function": "utilities", + "outputDescription": "output_description", + } + + +def test_update_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="def main(originCode: str)"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="def main(originCode: str)"): + with patch( + "aixplain.modules.model.utils.parse_code", + return_value=( + "def main(originCode: str)", + [UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + "utility_model_test", + ), + ): + mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_description="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.description = "updated_description" + utility_model.update() - assert utility_model.id == "123" - assert utility_model.description == "updated_description" + assert utility_model.id == "123" + assert utility_model.description == "updated_description" From 66d8fd9d12d70584d227f9a445ce810eb3f835f8 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Wed, 4 Dec 2024 10:57:34 -0300 Subject: [PATCH 4/6] Adapting utility model onboarding to unit tests --- aixplain/factories/model_factory/__init__.py | 14 +-- aixplain/modules/model/utility_model.py | 12 +-- aixplain/modules/model/utils.py | 11 ++- .../model/run_utility_model_test.py | 8 +- tests/unit/utility_test.py | 93 +++++++++++++------ 5 files changed, 90 insertions(+), 48 deletions(-) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 435213c4..ca035825 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -44,19 +44,19 @@ class ModelFactory: def create_utility_model( cls, name: Text, - description: Text, - inputs: List[UtilityModelInput], code: Union[Text, Callable], - output_description: Text, + inputs: List[UtilityModelInput] = [], + description: Optional[Text] = None, + output_examples: Text = "", ) -> UtilityModel: """Create a utility model Args: name (Text): name of the model - description (Text): description of the model - inputs (List[UtilityModelInput]): inputs of the model code (Union[Text, Callable]): code of the model - output_description (Text): description of the output + description (Text, optional): description of the model + inputs (List[UtilityModelInput], optional): inputs of the model + output_examples (Text, optional): output examples Returns: UtilityModel: created utility model @@ -69,7 +69,7 @@ def create_utility_model( code=code, function=Function.UTILITIES, api_key=config.TEAM_API_KEY, - output_description=output_description, + output_examples=output_examples, ) payload = utility_model.to_dict() url = urljoin(cls.backend_url, "sdk/utilities") diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 37c8b6f6..1bc40f67 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -55,7 +55,7 @@ class UtilityModel(Model): code (Union[Text, Callable]): code of the model. description (Text): description of the model. Defaults to "". inputs (List[UtilityModelInput]): inputs of the model. Defaults to []. - output_description (Text): description of the output. Defaults to "". + output_examples (Text): output examples. Defaults to "". api_key (Text, optional): API key of the Model. Defaults to None. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". @@ -72,7 +72,7 @@ def __init__( code: Union[Text, Callable], description: Optional[Text] = None, inputs: List[UtilityModelInput] = [], - output_description: Text = "", + output_examples: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -89,7 +89,7 @@ def __init__( code (Union[Text, Callable]): code of the model. description (Text): description of the model. Defaults to "". inputs (List[UtilityModelInput]): inputs of the model. Defaults to []. - output_description (Text): description of the output. Defaults to "". + output_examples (Text): output examples. Defaults to "". api_key (Text, optional): API key of the Model. Defaults to None. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". @@ -115,7 +115,7 @@ def __init__( self.backend_url = config.BACKEND_URL self.code = code self.inputs = inputs - self.output_description = output_description + self.output_examples = output_examples self.validate() def validate(self): @@ -128,7 +128,7 @@ def validate(self): assert self.name and self.name.strip() != "", "Name is required" assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" - assert self.output_description and self.output_description.strip() != "", "Output description is required" + assert self.output_examples and self.output_examples.strip() != "", "Output description is required" def to_dict(self): return { @@ -137,7 +137,7 @@ def to_dict(self): "inputs": [input.to_dict() for input in self.inputs], "code": self.code, "function": self.function.value, - "outputDescription": self.output_description, + "outputDescription": self.output_examples, } def update(self): diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index f2762080..68131c96 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -85,6 +85,7 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: import re import requests import validators + from aixplain.enums import DataType from aixplain.modules.model.utility_model import UtilityModelInput from aixplain.factories.file_factory import FileFactory from uuid import uuid4 @@ -93,7 +94,7 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: if isinstance(code, Callable): str_code = inspect.getsource(code) - description = code.__doc__ + description = code.__doc__.strip() if code.__doc__ else "" elif os.path.exists(code): with open(code, "r") as f: str_code = f.read() @@ -119,15 +120,15 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: if input_type in ["int", "float"]: input_type = "number" - inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + inputs.append(UtilityModelInput(name=input_name, type=DataType.NUMBER, description="")) elif input_type == "bool": input_type = "boolean" - inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + inputs.append(UtilityModelInput(name=input_name, type=DataType.BOOLEAN, description="")) elif input_type == "str": input_type = "text" - inputs.append(UtilityModelInput(name=input_name, type=input_type, description="")) + inputs.append(UtilityModelInput(name=input_name, type=DataType.TEXT, description="")) else: - raise Exception(f"Utility Model Error:Unsupported input type: {input_type}") + raise Exception(f"Utility Model Error: Unsupported input type: {input_type}") local_path = str(uuid4()) with open(local_path, "w") as f: diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index 5887c4ca..ce0b7579 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -14,17 +14,19 @@ def test_run_utility_model(): name="test_script", description="This is a test script", inputs=inputs, - code="def main(inputA):\n\treturn inputA", - output_description=output_description, + code="def main(inputA: str):\n\treturn inputA", + output_examples=output_description, ) assert utility_model.id is not None + assert utility_model.inputs == inputs + assert utility_model.output_examples == output_description response = utility_model.run(data={"inputA": "test"}) assert response.status == "SUCCESS" assert response.data == "test" - utility_model.code = "def main(inputA):\n\treturn 5" + utility_model.code = "def main(inputA: str):\n\treturn 5" utility_model.update() response = utility_model.run(data={"inputA": "test"}) assert response.status == "SUCCESS" diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 1988d0b1..4f46925c 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -5,6 +5,7 @@ from aixplain.utils import config from aixplain.enums import DataType, Function from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput +from aixplain.modules.model.utils import parse_code from unittest.mock import patch @@ -12,30 +13,19 @@ def test_utility_model(): with requests_mock.Mocker() as mock: with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): - with patch( - "aixplain.modules.model.utils.parse_code", - return_value=( - "def main(originCode: str)", - [UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], - "utility_model_test", - ), - ): - mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) - utility_model = ModelFactory.create_utility_model( - name="utility_model_test", - description="utility_model_test", - code="def main(originCode: str)", - inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], - output_description="output_description", - ) - assert utility_model.id == "123" - assert utility_model.name == "utility_model_test" - assert utility_model.description == "utility_model_test" - assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [ - UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT) - ] - assert utility_model.output_description == "output_description" + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_examples="output_description", + ) + assert utility_model.id == "123" + assert utility_model.name == "utility_model_test" + assert utility_model.description == "utility_model_test" + assert utility_model.code == "utility_model_test" + assert utility_model.inputs == [UtilityModelInput(name="originCode", description="", type=DataType.TEXT)] + assert utility_model.output_examples == "output_description" def test_utility_model_with_invalid_name(): @@ -55,7 +45,7 @@ def test_utility_model_with_invalid_name(): description="utility_model_test", code="def main(originCode: str)", inputs=[], - output_description="output_description", + output_examples="output_description", ) assert str(exc_info.value) == "Name is required" @@ -76,7 +66,7 @@ def test_utility_model_to_dict(): name="utility_model_test", description="utility_model_test", code="def main(originCode: str)", - output_description="output_description", + output_examples="output_description", inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], function=Function.UTILITIES, api_key=config.TEAM_API_KEY, @@ -109,7 +99,7 @@ def test_update_utility_model(): name="utility_model_test", description="utility_model_test", code="def main(originCode: str)", - output_description="output_description", + output_examples="output_description", inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], function=Function.UTILITIES, api_key=config.TEAM_API_KEY, @@ -119,3 +109,52 @@ def test_update_utility_model(): assert utility_model.id == "123" assert utility_model.description == "updated_description" + + +def test_parse_code(): + # Code is a string + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="code_link"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="code_link"): + code = "def main(originCode: str) -> str:\n return originCode" + code_link, inputs, description = parse_code(code) + assert inputs == [UtilityModelInput(name="originCode", description="", type=DataType.TEXT)] + assert description == "" + assert code_link == "code_link" + + # Code is a function + def main(a: int, b: int): + """ + This function adds two numbers + """ + return a + b + + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="code_link"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="code_link"): + code = main + code_link, inputs, description = parse_code(code) + assert inputs == [ + UtilityModelInput(name="a", description="", type=DataType.NUMBER), + UtilityModelInput(name="b", description="", type=DataType.NUMBER), + ] + assert description == "This function adds two numbers" + assert code_link == "code_link" + + # Code must have a main function + code = "def wrong_function_name(originCode: str) -> str:\n return originCode" + with pytest.raises(Exception) as exc_info: + parse_code(code) + assert str(exc_info.value) == "Utility Model Error: Code must have a main function" + + # Input type is required + def main(originCode): + return originCode + + with pytest.raises(Exception) as exc_info: + parse_code(main) + assert str(exc_info.value) == "Utility Model Error: Input type is required. For instance def main(a: int, b: int) -> int:" + + # Unsupported input type + code = "def main(originCode: list) -> str:\n return originCode" + with pytest.raises(Exception) as exc_info: + parse_code(code) + assert str(exc_info.value) == "Utility Model Error: Unsupported input type: list" From 9b3d3ba052411c4afb571ec2093b56803b19a982 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Wed, 4 Dec 2024 14:46:01 -0300 Subject: [PATCH 5/6] Unit test for utility removal --- tests/unit/utility_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 4f46925c..89803cac 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -111,6 +111,25 @@ def test_update_utility_model(): assert utility_model.description == "updated_description" +def test_delete_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="def main(originCode: str)"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="def main(originCode: str)"): + mock.delete(urljoin(config.BACKEND_URL, "sdk/utilities/123"), status_code=200, json={"id": "123"}) + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_examples="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.delete() + assert mock.called + + def test_parse_code(): # Code is a string with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="code_link"): From f830b598523f122fe462f36b1ffacc993c0125a3 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 10 Dec 2024 18:50:13 -0300 Subject: [PATCH 6/6] Auto-describe utility model inputs --- aixplain/modules/model/utils.py | 12 +++++++++--- tests/unit/utility_test.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 68131c96..f3691928 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -120,13 +120,19 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: if input_type in ["int", "float"]: input_type = "number" - inputs.append(UtilityModelInput(name=input_name, type=DataType.NUMBER, description="")) + inputs.append( + UtilityModelInput(name=input_name, type=DataType.NUMBER, description=f"The {input_name} input is a number") + ) elif input_type == "bool": input_type = "boolean" - inputs.append(UtilityModelInput(name=input_name, type=DataType.BOOLEAN, description="")) + inputs.append( + UtilityModelInput(name=input_name, type=DataType.BOOLEAN, description=f"The {input_name} input is a boolean") + ) elif input_type == "str": input_type = "text" - inputs.append(UtilityModelInput(name=input_name, type=DataType.TEXT, description="")) + inputs.append( + UtilityModelInput(name=input_name, type=DataType.TEXT, description=f"The {input_name} input is a text") + ) else: raise Exception(f"Utility Model Error: Unsupported input type: {input_type}") diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index 89803cac..595ab4ae 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -24,7 +24,9 @@ def test_utility_model(): assert utility_model.name == "utility_model_test" assert utility_model.description == "utility_model_test" assert utility_model.code == "utility_model_test" - assert utility_model.inputs == [UtilityModelInput(name="originCode", description="", type=DataType.TEXT)] + assert utility_model.inputs == [ + UtilityModelInput(name="originCode", description="The originCode input is a text", type=DataType.TEXT) + ] assert utility_model.output_examples == "output_description" @@ -136,7 +138,9 @@ def test_parse_code(): with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="code_link"): code = "def main(originCode: str) -> str:\n return originCode" code_link, inputs, description = parse_code(code) - assert inputs == [UtilityModelInput(name="originCode", description="", type=DataType.TEXT)] + assert inputs == [ + UtilityModelInput(name="originCode", description="The originCode input is a text", type=DataType.TEXT) + ] assert description == "" assert code_link == "code_link" @@ -152,8 +156,8 @@ def main(a: int, b: int): code = main code_link, inputs, description = parse_code(code) assert inputs == [ - UtilityModelInput(name="a", description="", type=DataType.NUMBER), - UtilityModelInput(name="b", description="", type=DataType.NUMBER), + UtilityModelInput(name="a", description="The a input is a number", type=DataType.NUMBER), + UtilityModelInput(name="b", description="The b input is a number", type=DataType.NUMBER), ] assert description == "This function adds two numbers" assert code_link == "code_link"