Skip to content

Commit

Permalink
Add Langstudio embedding connection for embedder_operator (#338)
Browse files Browse the repository at this point in the history
* Add Langstudio embedding connection for embedder_operator

* Modify langstudio connection config

* Fix import error for langstudio sdk

* Add langstudio sdk to rag op dockerfile
  • Loading branch information
wwxxzz authored Jan 16, 2025
1 parent 4d75d9c commit b285c03
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/pai_rag/integrations/embeddings/pai/embedding_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import os
from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pai_rag.utils.download_models import ModelScopeDownloader
from pai_rag.integrations.embeddings.pai.pai_embedding_config import (
PaiBaseEmbeddingConfig,
DashScopeEmbeddingConfig,
OpenAIEmbeddingConfig,
HuggingFaceEmbeddingConfig,
CnClipEmbeddingConfig,
LangStudioEmbeddingConfig,
)

from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pai_rag.integrations.embeddings.clip.cnclip_embedding import CnClipEmbedding
import os
from loguru import logger
from pai_rag.utils.download_models import ModelScopeDownloader


def create_embedding(
Expand All @@ -22,6 +22,7 @@ def create_embedding(
if isinstance(embed_config, OpenAIEmbeddingConfig):
embed_model = OpenAIEmbedding(
api_key=embed_config.api_key,
api_base=embed_config.api_base,
embed_batch_size=embed_config.embed_batch_size,
callback_manager=Settings.callback_manager,
)
Expand Down Expand Up @@ -91,7 +92,16 @@ def create_embedding(
logger.info(
f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."
)
elif isinstance(embed_config, LangStudioEmbeddingConfig):
from pai_rag.integrations.embeddings.pai.langstudio_utils import (
convert_langstudio_embed_config,
)

converted_embed_config = convert_langstudio_embed_config(embed_config)
logger.info(
f"Initialized LangStudio embedding model with {converted_embed_config}."
)
return create_embedding(converted_embed_config, pai_rag_model_dir)
else:
raise ValueError(f"Unknown Embedding source: {embed_config}")

Expand Down
100 changes: 100 additions & 0 deletions src/pai_rag/integrations/embeddings/pai/langstudio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_pailangstudio20240710.client import Client as LangStudioClient
from alibabacloud_pailangstudio20240710.models import (
GetConnectionRequest,
ListConnectionsRequest,
)
from alibabacloud_tea_openapi import models as open_api_models
from pai_rag.utils.constants import DEFAULT_DASHSCOPE_EMBEDDING_MODEL
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config
from loguru import logger


def get_region_id():
return next(
(
os.environ[key]
for key in ["REGION", "REGION_ID", "ALIBABA_CLOUD_REGION_ID"]
if key in os.environ and os.environ[key]
),
"cn-hangzhou",
)


def get_connection_info(region_id: str, connection_name: str, workspace_id: str):
"""
Get Connection information from LangStudio API.
"""
config1 = CredentialConfig(
type="access_key",
access_key_id=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID"),
access_key_secret=os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
)
public_endpoint = f"pailangstudio.{region_id}.aliyuncs.com"
client = LangStudioClient(
config=open_api_models.Config(
# Use default credential chain, see:
# https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
credential=CredentialClient(config=config1),
endpoint=public_endpoint,
)
)
resp = client.list_connections(
request=ListConnectionsRequest(
connection_name=connection_name, workspace_id=workspace_id, max_results=50
)
)
connection_info = next(
(
conn
for conn in resp.body.connections
if conn.connection_name == connection_name
),
None,
)
if not connection_info:
raise ValueError(f"Connection {connection_name} not found")
ls_connection = client.get_connection(
connection_id=connection_info.connection_id,
request=GetConnectionRequest(
workspace_id=workspace_id,
encrypt_option="PlainText",
),
)
conn_info = ls_connection.body
configs = conn_info.configs or {}
secrets = conn_info.secrets or {}

logger.info(f"Configs conn_info:\n {conn_info}")
return conn_info, configs, secrets


def convert_langstudio_embed_config(embed_config):
region_id = embed_config.region_id or get_region_id()
conn_info, config, secrets = get_connection_info(
region_id, embed_config.connection_name, embed_config.workspace_id
)
if conn_info.custom_type == "OpenEmbeddingConnection":
return parse_embed_config(
{
"source": "openai",
"api_key": secrets.get("api_key", None),
"api_base": config.get("base_url", None),
"model": embed_config.model,
"embed_batch_size": embed_config.embed_batch_size,
}
)
elif conn_info.custom_type == "DashScopeConnection":
return parse_embed_config(
{
"source": "dashscope",
"api_key": secrets.get("api_key", None)
or os.getenv("DASHSCOPE_API_KEY"),
"model": embed_config.model or DEFAULT_DASHSCOPE_EMBEDDING_MODEL,
"embed_batch_size": embed_config.embed_batch_size,
}
)
else:
raise ValueError(f"Unknown connection type: {conn_info.custom_type}")
10 changes: 10 additions & 0 deletions src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class SupportedEmbedType(str, Enum):
openai = "openai"
huggingface = "huggingface"
cnclip = "cnclip" # Chinese CLIP
langstudio = "langstudio"


class PaiBaseEmbeddingConfig(BaseModel):
Expand Down Expand Up @@ -42,6 +43,7 @@ class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai
model: str | None = None # use default
api_key: str | None = None # use default
api_base: str | None = None # use default


class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig):
Expand All @@ -54,6 +56,14 @@ class CnClipEmbeddingConfig(PaiBaseEmbeddingConfig):
model: str | None = "ViT-L-14"


class LangStudioEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.langstudio] = SupportedEmbedType.langstudio
region_id: str | None = "cn-hangzhou" # use default
connection_name: str | None = None
workspace_id: str | None = None
model: str | None = None


SupporttedEmbeddingClsMap = {
cls.get_type(): cls for cls in PaiBaseEmbeddingConfig.get_subclasses()
}
Expand Down
1 change: 1 addition & 0 deletions src/pai_rag/tools/data_process/docker/Dockerfile_cpu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ COPY . .

RUN poetry install && rm -rf $POETRY_CACHE_DIR
RUN poetry run aliyun-bootstrap -a install
RUN pip3 install https://sdk-portal-us-prod.oss-accelerate.aliyuncs.com/downloads/u-5fa6e81f-04cd-41d6-86ac-d8bffa4525e7-python-tea.zip
RUN pip3 install ray[default]

FROM python:3.11-slim AS prod
Expand Down
4 changes: 4 additions & 0 deletions src/pai_rag/tools/data_process/ops/embed_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
enable_sparse: bool = False,
enable_multimodal: bool = False,
multimodal_source: str = None,
connection_name: str = None,
workspace_id: str = None,
*args,
**kwargs,
):
Expand All @@ -39,6 +41,8 @@ def __init__(
"source": source,
"model": model,
"enable_sparse": enable_sparse,
"connection_name": connection_name,
"workspace_id": workspace_id,
}
)
# Init model download list
Expand Down
14 changes: 14 additions & 0 deletions src/pai_rag/tools/data_process/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def process_embedder(args):
"enable_sparse",
"enable_multimodal",
"multimodal_source",
"connection_name",
"workspace_id",
]
}
args.process.append("rag_embedder")
Expand Down Expand Up @@ -283,6 +285,18 @@ def init_configs():
default="cnclip",
help="Multi-modal embedding model source for rag_embedder operator.",
)
parser.add_argument(
"--connection_name",
type=str,
default=None,
help="Langstudio connection for rag_embedder operator.",
)
parser.add_argument(
"--workspace_id",
type=str,
default=None,
help="PAI workspace id for rag_embedder operator.",
)

args = parser.parse_args()

Expand Down
2 changes: 2 additions & 0 deletions src/pai_rag/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@
)

DEFAULT_DATAFILE_DIR = "./data"

DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2"

0 comments on commit b285c03

Please # to comment.