From 221375ff6230517d4cae4f1db6b3cae211fb6112 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 20 Mar 2025 15:06:48 +0000 Subject: [PATCH 1/8] add vectara support dynamically --- aixplain/enums/__init__.py | 1 + aixplain/enums/index_stores.py | 14 +++++++ .../__init__.py} | 40 ++++++++++++++----- aixplain/factories/index_factory/utils.py | 39 ++++++++++++++++++ 4 files changed, 84 insertions(+), 10 deletions(-) create mode 100644 aixplain/enums/index_stores.py rename aixplain/factories/{index_factory.py => index_factory/__init__.py} (51%) create mode 100644 aixplain/factories/index_factory/utils.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 4f0364e1..e80c03c6 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -18,3 +18,4 @@ from .database_source import DatabaseSourceType from .embedding_model import EmbeddingModel from .asset_status import AssetStatus +from .index_stores import IndexStores diff --git a/aixplain/enums/index_stores.py b/aixplain/enums/index_stores.py new file mode 100644 index 00000000..379a460e --- /dev/null +++ b/aixplain/enums/index_stores.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class IndexStores(Enum): + AIR = {"name": "air", "id": "66eae6656eb56311f2595011"} + VECTARA = {"name": "vectara", "id": "655e20f46eb563062a1aa301"} + GRAPHRAG = {"name": "graphrag", "id": "67dd6d487cbf0a57cf4b72f3"} + # ZERO_ENTROPY = {"name": "zero_entropy", "id": ""} + + def __str__(self): + return self.value["name"] + + def get_model_id(self): + return self.value["id"] diff --git a/aixplain/factories/index_factory.py b/aixplain/factories/index_factory/__init__.py similarity index 51% rename from aixplain/factories/index_factory.py rename to aixplain/factories/index_factory/__init__.py index 7588e583..2e65330d 100644 --- a/aixplain/factories/index_factory.py +++ b/aixplain/factories/index_factory/__init__.py @@ -1,20 +1,40 @@ +__author__ = "aiXplain" + +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Abdul Basit Anees, Thiago Castro Ferreira, Zaina Abushaban +Date: December 26th 2024 +Description: + Index Factory Class +""" + from aixplain.modules.model.index_model import IndexModel from aixplain.factories import ModelFactory -from aixplain.enums import EmbeddingModel, Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier -from typing import Optional, Text, Union, List, Tuple - -AIR_MODEL_ID = "66eae6656eb56311f2595011" +from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier +from typing import Text, Union, List, Tuple, Optional +from aixplain.factories.index_factory.utils import BaseIndexParams, get_model_id_and_payload class IndexFactory(ModelFactory): @classmethod - def create( - cls, name: Text, description: Text, embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002 - ) -> IndexModel: + def create(cls, params: BaseIndexParams) -> IndexModel: """Create a new index collection""" - model = cls.get(AIR_MODEL_ID) + model_id, data = get_model_id_and_payload(params) + model = cls.get(model_id) - data = {"data": name, "description": description, "model": embedding_model.value} response = model.run(data=data) if response.status == ResponseStatus.SUCCESS: model_id = response.data @@ -23,7 +43,7 @@ def create( error_message = f"Index Factory Exception: {response.error_message}" if error_message == "": - error_message = "Index Factory Exception: An error occurred while creating the index collection." + error_message = "Index Factory Exception:An error occurred while creating the index collection." raise Exception(error_message) @classmethod diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py new file mode 100644 index 00000000..4258251c --- /dev/null +++ b/aixplain/factories/index_factory/utils.py @@ -0,0 +1,39 @@ +from pydantic import BaseModel, ConfigDict +from typing import Text, Optional, Tuple, Dict +from aixplain.enums import IndexStores, EmbeddingModel + +class BaseIndexParams(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + name: Text + description: Optional[Text] = "" + +class VectaraParams(BaseIndexParams): + pass + +class ZeroEntropyParams(BaseIndexParams): + pass + +class AirParams(BaseIndexParams): + embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 # should allow all embedding model ids as this is not very scalable + +class GraphRAGParams(BaseIndexParams): + embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + llm_model: Optional[Text] = "669a63646eb56306647e1091" # Gpt-4o-mini + +def get_model_id_and_payload(params: BaseIndexParams) -> Tuple[Text, Dict]: + payload = params.model_dump() + if isinstance(params, AirParams): + model_id = IndexStores.AIR.get_model_id() + payload["model"] = payload.pop("embedding_model") + elif isinstance(params, GraphRAGParams): + model_id = IndexStores.GRAPHRAG.get_model_id() + payload["model"] = payload.pop("embedding_model") + elif isinstance(params, VectaraParams): + model_id = IndexStores.VECTARA.get_model_id() + elif isinstance(params, ZeroEntropyParams): + # model_id = IndexStores.ZERO_ENTROPY.get_model_id() + raise ValueError("ZeroEntropy is not supported yet") + else: + raise ValueError(f"Invalid index params: {params}") + return model_id, payload From 976b5e3696ea0eaa1817c1dabc85750feed00c4a Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 15:51:47 +0000 Subject: [PATCH 2/8] remove space --- aixplain/factories/index_factory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 2e65330d..f5a63f61 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -43,7 +43,7 @@ def create(cls, params: BaseIndexParams) -> IndexModel: error_message = f"Index Factory Exception: {response.error_message}" if error_message == "": - error_message = "Index Factory Exception:An error occurred while creating the index collection." + error_message = "Index Factory Exception: An error occurred while creating the index collection." raise Exception(error_message) @classmethod From 3d20951bd74d9b2c7461b77b9db69315b56571bb Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 17:11:51 +0000 Subject: [PATCH 3/8] refactor --- aixplain/factories/index_factory/__init__.py | 6 ++- aixplain/factories/index_factory/utils.py | 51 +++++++++++--------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index f5a63f61..43d8326d 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -25,14 +25,16 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier from typing import Text, Union, List, Tuple, Optional -from aixplain.factories.index_factory.utils import BaseIndexParams, get_model_id_and_payload +from aixplain.factories.index_factory.utils import BaseIndexParams class IndexFactory(ModelFactory): @classmethod def create(cls, params: BaseIndexParams) -> IndexModel: """Create a new index collection""" - model_id, data = get_model_id_and_payload(params) + model_id = params.get_model_id() + data = params.to_dict() + model = cls.get(model_id) response = model.run(data=data) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index 4258251c..121abdd4 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -8,32 +8,39 @@ class BaseIndexParams(BaseModel): name: Text description: Optional[Text] = "" + def to_dict(self): + return self.model_dump() + +class IndexParamsWithEmbeddingModel(BaseIndexParams): + + embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + + def to_dict(self): + data = super().to_dict() + data["model"] = data.pop("embedding_model").value + return data + class VectaraParams(BaseIndexParams): - pass + + def get_model_id(self): + return IndexStores.VECTARA.get_model_id() + class ZeroEntropyParams(BaseIndexParams): - pass + + def get_model_id(self): + raise ValueError("ZeroEntropy is not supported yet") + # return IndexStores.ZERO_ENTROPY.get_model_id() -class AirParams(BaseIndexParams): - embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 # should allow all embedding model ids as this is not very scalable +class AirParams(IndexParamsWithEmbeddingModel): -class GraphRAGParams(BaseIndexParams): - embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + def get_model_id(self): + return IndexStores.AIR.get_model_id() + + +class GraphRAGParams(IndexParamsWithEmbeddingModel): llm_model: Optional[Text] = "669a63646eb56306647e1091" # Gpt-4o-mini -def get_model_id_and_payload(params: BaseIndexParams) -> Tuple[Text, Dict]: - payload = params.model_dump() - if isinstance(params, AirParams): - model_id = IndexStores.AIR.get_model_id() - payload["model"] = payload.pop("embedding_model") - elif isinstance(params, GraphRAGParams): - model_id = IndexStores.GRAPHRAG.get_model_id() - payload["model"] = payload.pop("embedding_model") - elif isinstance(params, VectaraParams): - model_id = IndexStores.VECTARA.get_model_id() - elif isinstance(params, ZeroEntropyParams): - # model_id = IndexStores.ZERO_ENTROPY.get_model_id() - raise ValueError("ZeroEntropy is not supported yet") - else: - raise ValueError(f"Invalid index params: {params}") - return model_id, payload + def get_model_id(self): + return IndexStores.GRAPHRAG.get_model_id() + From f5214ec091d71b584d68f4d26bf3f027699cdef7 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 17:12:32 +0000 Subject: [PATCH 4/8] remove comment --- aixplain/enums/index_stores.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aixplain/enums/index_stores.py b/aixplain/enums/index_stores.py index 379a460e..38a5e9a2 100644 --- a/aixplain/enums/index_stores.py +++ b/aixplain/enums/index_stores.py @@ -5,7 +5,6 @@ class IndexStores(Enum): AIR = {"name": "air", "id": "66eae6656eb56311f2595011"} VECTARA = {"name": "vectara", "id": "655e20f46eb563062a1aa301"} GRAPHRAG = {"name": "graphrag", "id": "67dd6d487cbf0a57cf4b72f3"} - # ZERO_ENTROPY = {"name": "zero_entropy", "id": ""} def __str__(self): return self.value["name"] From 3d1f3b0b9171fa65a34a23d452c05ec45446e677 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 17:27:07 +0000 Subject: [PATCH 5/8] handle non enum and add constant for llm --- aixplain/factories/index_factory/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index 121abdd4..d69df0be 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -2,6 +2,8 @@ from typing import Text, Optional, Tuple, Dict from aixplain.enums import IndexStores, EmbeddingModel +GPT_4O_MINI_ID = "669a63646eb56306647e1091" + class BaseIndexParams(BaseModel): model_config = ConfigDict(use_enum_values=True) @@ -17,7 +19,7 @@ class IndexParamsWithEmbeddingModel(BaseIndexParams): def to_dict(self): data = super().to_dict() - data["model"] = data.pop("embedding_model").value + data["model"] = data.pop("embedding_model").value if isinstance(self.embedding_model, EmbeddingModel) else data.pop("embedding_model") return data class VectaraParams(BaseIndexParams): @@ -39,7 +41,7 @@ def get_model_id(self): class GraphRAGParams(IndexParamsWithEmbeddingModel): - llm_model: Optional[Text] = "669a63646eb56306647e1091" # Gpt-4o-mini + llm_model: Optional[Text] = GPT_4O_MINI_ID def get_model_id(self): return IndexStores.GRAPHRAG.get_model_id() From 317eed54ba07fdcf30ab6299a0e7e8c0261d2efe Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 21:43:43 +0000 Subject: [PATCH 6/8] restructure and use Generics --- aixplain/factories/index_factory/__init__.py | 9 ++--- aixplain/factories/index_factory/utils.py | 38 +++++++++----------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 43d8326d..7a3f65c8 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -24,15 +24,16 @@ from aixplain.modules.model.index_model import IndexModel from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier -from typing import Text, Union, List, Tuple, Optional -from aixplain.factories.index_factory.utils import BaseIndexParams +from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic +from aixplain.factories.index_factory.utils import BaseIndexParams +T = TypeVar('T', bound=BaseIndexParams) class IndexFactory(ModelFactory): @classmethod - def create(cls, params: BaseIndexParams) -> IndexModel: + def create(cls, params: T) -> IndexModel: """Create a new index collection""" - model_id = params.get_model_id() + model_id = params.id data = params.to_dict() model = cls.get(model_id) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index d69df0be..204ea3c9 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -1,20 +1,19 @@ from pydantic import BaseModel, ConfigDict -from typing import Text, Optional, Tuple, Dict +from typing import Text, Optional, ClassVar from aixplain.enums import IndexStores, EmbeddingModel -GPT_4O_MINI_ID = "669a63646eb56306647e1091" class BaseIndexParams(BaseModel): model_config = ConfigDict(use_enum_values=True) - name: Text + data: Text description: Optional[Text] = "" def to_dict(self): - return self.model_dump() - -class IndexParamsWithEmbeddingModel(BaseIndexParams): - + return self.model_dump(exclude_none=True) + + +class BaseIndexParamsWithEmbeddingModel(BaseIndexParams): embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 def to_dict(self): @@ -22,27 +21,22 @@ def to_dict(self): data["model"] = data.pop("embedding_model").value if isinstance(self.embedding_model, EmbeddingModel) else data.pop("embedding_model") return data + class VectaraParams(BaseIndexParams): - - def get_model_id(self): - return IndexStores.VECTARA.get_model_id() + id: ClassVar[str] = IndexStores.VECTARA.get_model_id() class ZeroEntropyParams(BaseIndexParams): - - def get_model_id(self): + id: ClassVar[str] = "" + + def __init__(self, **kwargs): raise ValueError("ZeroEntropy is not supported yet") - # return IndexStores.ZERO_ENTROPY.get_model_id() -class AirParams(IndexParamsWithEmbeddingModel): - def get_model_id(self): - return IndexStores.AIR.get_model_id() +class AirParams(BaseIndexParamsWithEmbeddingModel): + id: ClassVar[str] = IndexStores.AIR.get_model_id() -class GraphRAGParams(IndexParamsWithEmbeddingModel): - llm_model: Optional[Text] = GPT_4O_MINI_ID - - def get_model_id(self): - return IndexStores.GRAPHRAG.get_model_id() - +class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): + id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() + llm_model: Optional[Text] = None From 77e2458e0e7de4c7b82b8a70892e84926340f334 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Thu, 27 Mar 2025 21:47:55 +0000 Subject: [PATCH 7/8] minor fix --- aixplain/factories/index_factory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 7a3f65c8..b994cb8d 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -29,7 +29,7 @@ T = TypeVar('T', bound=BaseIndexParams) -class IndexFactory(ModelFactory): +class IndexFactory(ModelFactory, Generic[T]): @classmethod def create(cls, params: T) -> IndexModel: """Create a new index collection""" From c6d4f07834a058ab4033bbdbb8f0a1227551bce6 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <abdulbasitanees98@gmail.com> Date: Mon, 31 Mar 2025 12:29:15 +0000 Subject: [PATCH 8/8] convert to abstract class, fix param name for llm --- aixplain/factories/index_factory/utils.py | 44 +++++++++++++++++------ 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index 204ea3c9..d7488c36 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -1,42 +1,64 @@ from pydantic import BaseModel, ConfigDict from typing import Text, Optional, ClassVar from aixplain.enums import IndexStores, EmbeddingModel +from abc import ABC, abstractmethod -class BaseIndexParams(BaseModel): +class BaseIndexParams(BaseModel, ABC): model_config = ConfigDict(use_enum_values=True) - data: Text description: Optional[Text] = "" def to_dict(self): return self.model_dump(exclude_none=True) + @property + @abstractmethod + def id(self) -> str: + """Abstract property that must be implemented in subclasses.""" + pass + -class BaseIndexParamsWithEmbeddingModel(BaseIndexParams): +class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 def to_dict(self): data = super().to_dict() - data["model"] = data.pop("embedding_model").value if isinstance(self.embedding_model, EmbeddingModel) else data.pop("embedding_model") + data["model"] = ( + data.pop("embedding_model").value + if isinstance(self.embedding_model, EmbeddingModel) + else data.pop("embedding_model") + ) return data class VectaraParams(BaseIndexParams): - id: ClassVar[str] = IndexStores.VECTARA.get_model_id() - + _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() + + @property + def id(self) -> str: + return self._id + class ZeroEntropyParams(BaseIndexParams): - id: ClassVar[str] = "" + _id: ClassVar[str] = "" def __init__(self, **kwargs): raise ValueError("ZeroEntropy is not supported yet") class AirParams(BaseIndexParamsWithEmbeddingModel): - id: ClassVar[str] = IndexStores.AIR.get_model_id() - + _id: ClassVar[str] = IndexStores.AIR.get_model_id() + + @property + def id(self) -> str: + return self._id + class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): - id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() - llm_model: Optional[Text] = None + _id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() + llm: Optional[Text] = None + + @property + def id(self) -> str: + return self._id