Skip to content

Commit

Permalink
fix: issue with dynamic inputs when selecting model (#4538)
Browse files Browse the repository at this point in the history
  • Loading branch information
erichare authored Nov 12, 2024
1 parent 0dc6cce commit 1dfa160
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 44 deletions.
90 changes: 51 additions & 39 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
),
]

def del_fields(self, build_config, field_list):
for field in field_list:
if field in build_config:
del build_config[field]

return build_config

def insert_in_dict(self, build_config, field_name, new_parameters):
# Insert the new key-value pair after the found key
for new_field_name, new_parameter in new_parameters.items():
Expand All @@ -234,31 +241,30 @@ def insert_in_dict(self, build_config, field_name, new_parameters):
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_service":
if field_value == "Astra Vectorize":
for field in ["embedding"]:
if field in build_config:
del build_config[field]
self.del_fields(build_config, ["embedding"])

new_parameter = DropdownInput(
name="provider",
display_name="Vectorize Provider",
name="embedding_provider",
display_name="Embedding Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
value="",
required=True,
real_time_refresh=True,
).to_dict()

self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
self.insert_in_dict(build_config, "embedding_service", {"embedding_provider": new_parameter})
else:
for field in [
"provider",
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
self.del_fields(
build_config,
[
"embedding_provider",
"model",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
],
)

new_parameter = HandleInput(
name="embedding",
Expand All @@ -269,32 +275,35 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})

elif field_name == "provider":
for field in [
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
elif field_name == "embedding_provider":
self.del_fields(
build_config,
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]

new_parameter_0 = DropdownInput(
name="z_00_model_name",
display_name="Model Name",
new_parameter = DropdownInput(
name="model",
display_name="Model",
info="The embedding model to use for the selected provider. Each provider has a different set of "
"models available (full list at "
"https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n"
f"{', '.join(model_options)}",
options=model_options,
placeholder="Select a model",
value=model_options[0],
value=None,
required=True,
real_time_refresh=True,
).to_dict()

self.insert_in_dict(build_config, "embedding_provider", {"model": new_parameter})

elif field_name == "model":
self.del_fields(
build_config,
["z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

new_parameter_1 = DictInput(
name="z_01_model_parameters",
display_name="Model Parameters",
Expand All @@ -303,12 +312,13 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

new_parameter_2 = MessageTextInput(
name="z_02_api_key_name",
display_name="API Key name",
display_name="API Key Name",
info="The name of the embeddings provider API key stored on Astra. "
"If set, it will override the 'ProviderKey' in the authentication parameters.",
).to_dict()

new_parameter_3 = SecretStrInput(
load_from_db=False,
name="z_03_provider_api_key",
display_name="Provider API Key",
info="An alternative to the Astra Authentication that passes an API key for the provider "
Expand All @@ -319,15 +329,14 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

new_parameter_4 = DictInput(
name="z_04_authentication",
display_name="Authentication parameters",
display_name="Authentication Parameters",
is_list=True,
).to_dict()

self.insert_in_dict(
build_config,
"provider",
"model",
{
"z_00_model_name": new_parameter_0,
"z_01_model_parameters": new_parameter_1,
"z_02_api_key_name": new_parameter_2,
"z_03_provider_api_key": new_parameter_3,
Expand All @@ -339,8 +348,8 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

def build_vectorize_options(self, **kwargs):
for attribute in [
"provider",
"z_00_model_name",
"embedding_provider",
"model",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
Expand All @@ -350,8 +359,10 @@ def build_vectorize_options(self, **kwargs):
setattr(self, attribute, None)

# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
model_name = self.z_00_model_name or kwargs.get("z_00_model_name")
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.embedding_provider, [None])[0] or kwargs.get(
"embedding_provider"
)
model_name = self.model or kwargs.get("model")
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}
parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {})

Expand Down Expand Up @@ -414,6 +425,7 @@ def build_vector_store(self, vectorize_options=None):
),
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
}

try:
vector_store = AstraDBVectorStore(
collection_name=self.collection_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_astra_vectorize():
store = None
try:
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"}
options_comp = {"embedding_provider": "nvidia", "model": "NV-Embed-QA"}

store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION,
Expand Down Expand Up @@ -150,8 +150,8 @@ def test_astra_vectorize_with_provider_api_key():
}

options_comp = {
"provider": "openai",
"z_00_model_name": "text-embedding-3-small",
"embedding_provider": "openai",
"model": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_03_provider_api_key": "openai",
"z_04_authentication": {},
Expand Down Expand Up @@ -206,8 +206,8 @@ def test_astra_vectorize_passes_authentication():
"authentication": {"providerKey": "openai"},
}
options_comp = {
"provider": "openai",
"z_00_model_name": "text-embedding-3-small",
"embedding_provider": "openai",
"model": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_04_authentication": {"providerKey": "openai"},
}
Expand Down

0 comments on commit 1dfa160

Please # to comment.