-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Langstudio embedding connection for embedder_operator (#338)
* Add LangStudio embedding connection for embedder_operator
* Modify LangStudio connection config
* Fix import error for LangStudio SDK
* Add LangStudio SDK to RAG op Dockerfile
- Loading branch information
Showing
7 changed files
with
148 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
src/pai_rag/integrations/embeddings/pai/langstudio_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
from alibabacloud_credentials.client import Client as CredentialClient | ||
from alibabacloud_credentials.models import Config as CredentialConfig | ||
from alibabacloud_pailangstudio20240710.client import Client as LangStudioClient | ||
from alibabacloud_pailangstudio20240710.models import ( | ||
GetConnectionRequest, | ||
ListConnectionsRequest, | ||
) | ||
from alibabacloud_tea_openapi import models as open_api_models | ||
from pai_rag.utils.constants import DEFAULT_DASHSCOPE_EMBEDDING_MODEL | ||
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config | ||
from loguru import logger | ||
|
||
|
||
def get_region_id():
    """Return the Alibaba Cloud region ID configured in the environment.

    The variables ``REGION``, ``REGION_ID`` and ``ALIBABA_CLOUD_REGION_ID``
    are consulted in that order; the first one holding a non-blank value
    wins.

    Returns:
        The region ID with surrounding whitespace removed, or
        ``"cn-hangzhou"`` when none of the variables is usefully set.
    """
    for key in ("REGION", "REGION_ID", "ALIBABA_CLOUD_REGION_ID"):
        # The value is interpolated into an endpoint hostname by callers, so
        # stray whitespace (or a whitespace-only value) must not leak through.
        value = os.environ.get(key, "").strip()
        if value:
            return value
    return "cn-hangzhou"
|
||
|
||
def get_connection_info(region_id: str, connection_name: str, workspace_id: str):
    """Look up a LangStudio connection by name and fetch its details.

    The connection is first located with ``list_connections`` (exact name
    match within the workspace) and then fetched with ``get_connection``
    using ``encrypt_option="PlainText"``.

    Args:
        region_id: Region used to build the ``pailangstudio.<region>``
            endpoint hostname.
        connection_name: Exact name of the connection to look up.
        workspace_id: Workspace that owns the connection.

    Returns:
        A ``(conn_info, configs, secrets)`` tuple: the ``get_connection``
        response body, plus its ``configs`` and ``secrets`` attributes, each
        replaced by an empty dict when falsy.

    Raises:
        ValueError: If no connection with exactly ``connection_name`` is
            present in the listing.
    """
    access_key_id = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID")
    access_key_secret = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET")
    if access_key_id and access_key_secret:
        credential = CredentialClient(
            config=CredentialConfig(
                type="access_key",
                access_key_id=access_key_id,
                access_key_secret=access_key_secret,
            )
        )
    else:
        # No explicit access key in the environment: use the default
        # credential chain instead of building a config with None values, see:
        # https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
        credential = CredentialClient()

    public_endpoint = f"pailangstudio.{region_id}.aliyuncs.com"
    client = LangStudioClient(
        config=open_api_models.Config(
            credential=credential,
            endpoint=public_endpoint,
        )
    )

    resp = client.list_connections(
        request=ListConnectionsRequest(
            connection_name=connection_name, workspace_id=workspace_id, max_results=50
        )
    )
    # Filter for an exact name match, as the listing is not assumed to do so;
    # also tolerate a response without a connections list.
    connection_info = next(
        (
            conn
            for conn in (resp.body.connections or [])
            if conn.connection_name == connection_name
        ),
        None,
    )
    if not connection_info:
        raise ValueError(f"Connection {connection_name} not found")

    ls_connection = client.get_connection(
        connection_id=connection_info.connection_id,
        request=GetConnectionRequest(
            workspace_id=workspace_id,
            encrypt_option="PlainText",
        ),
    )
    conn_info = ls_connection.body
    configs = conn_info.configs or {}
    secrets = conn_info.secrets or {}

    # conn_info was requested in plain text and carries decrypted secrets, so
    # log only identifying metadata rather than the whole response body.
    logger.info(
        f"Loaded LangStudio connection '{connection_info.connection_name}' "
        f"(id={connection_info.connection_id}, type={conn_info.custom_type})"
    )
    return conn_info, configs, secrets
|
||
|
||
def convert_langstudio_embed_config(embed_config):
    """Translate a LangStudio-backed embedding config into a concrete one.

    The referenced connection is fetched via ``get_connection_info`` and,
    depending on its ``custom_type``, an ``openai`` or ``dashscope`` source
    config is built and handed to ``parse_embed_config``.

    Args:
        embed_config: Object exposing ``region_id``, ``connection_name``,
            ``workspace_id``, ``model`` and ``embed_batch_size``.

    Returns:
        Whatever ``parse_embed_config`` returns for the derived settings.

    Raises:
        ValueError: If the connection's ``custom_type`` is not one of the
            two supported connection types.
    """
    resolved_region = embed_config.region_id
    if not resolved_region:
        resolved_region = get_region_id()

    connection, settings, credentials = get_connection_info(
        resolved_region, embed_config.connection_name, embed_config.workspace_id
    )
    connection_type = connection.custom_type

    if connection_type == "OpenEmbeddingConnection":
        params = {
            "source": "openai",
            "api_key": credentials.get("api_key"),
            "api_base": settings.get("base_url"),
            "model": embed_config.model,
            "embed_batch_size": embed_config.embed_batch_size,
        }
    elif connection_type == "DashScopeConnection":
        # Fall back to the environment when the connection stores no key.
        dashscope_key = credentials.get("api_key")
        if not dashscope_key:
            dashscope_key = os.getenv("DASHSCOPE_API_KEY")
        params = {
            "source": "dashscope",
            "api_key": dashscope_key,
            "model": embed_config.model or DEFAULT_DASHSCOPE_EMBEDDING_MODEL,
            "embed_batch_size": embed_config.embed_batch_size,
        }
    else:
        raise ValueError(f"Unknown connection type: {connection_type}")

    return parse_embed_config(params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,3 +21,5 @@ | |
) | ||
|
||
DEFAULT_DATAFILE_DIR = "./data" | ||
|
||
DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2" |