Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix: issue with dynamic inputs when selecting model #4538

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
),
]

def del_fields(self, build_config, field_list):
for field in field_list:
if field in build_config:
del build_config[field]

return build_config

def insert_in_dict(self, build_config, field_name, new_parameters):
# Insert the new key-value pair after the found key
for new_field_name, new_parameter in new_parameters.items():
Expand All @@ -234,31 +241,30 @@ def insert_in_dict(self, build_config, field_name, new_parameters):
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name == "embedding_service":
if field_value == "Astra Vectorize":
for field in ["embedding"]:
if field in build_config:
del build_config[field]
self.del_fields(build_config, ["embedding"])

new_parameter = DropdownInput(
name="provider",
display_name="Vectorize Provider",
name="embedding_provider",
display_name="Embedding Provider",
options=self.VECTORIZE_PROVIDERS_MAPPING.keys(),
value="",
required=True,
real_time_refresh=True,
).to_dict()

self.insert_in_dict(build_config, "embedding_service", {"provider": new_parameter})
self.insert_in_dict(build_config, "embedding_service", {"embedding_provider": new_parameter})
else:
for field in [
"provider",
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
self.del_fields(
build_config,
[
"embedding_provider",
"model",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
],
)

new_parameter = HandleInput(
name="embedding",
Expand All @@ -269,32 +275,35 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

self.insert_in_dict(build_config, "embedding_service", {"embedding": new_parameter})

elif field_name == "provider":
for field in [
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_authentication",
]:
if field in build_config:
del build_config[field]
elif field_name == "embedding_provider":
self.del_fields(
build_config,
["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

model_options = self.VECTORIZE_PROVIDERS_MAPPING[field_value][1]

new_parameter_0 = DropdownInput(
name="z_00_model_name",
display_name="Model Name",
new_parameter = DropdownInput(
name="model",
display_name="Model",
info="The embedding model to use for the selected provider. Each provider has a different set of "
"models available (full list at "
"https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n"
f"{', '.join(model_options)}",
options=model_options,
placeholder="Select a model",
value=model_options[0],
value=None,
required=True,
real_time_refresh=True,
).to_dict()

self.insert_in_dict(build_config, "embedding_provider", {"model": new_parameter})

elif field_name == "model":
self.del_fields(
build_config,
["z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"],
)

new_parameter_1 = DictInput(
name="z_01_model_parameters",
display_name="Model Parameters",
Expand All @@ -303,12 +312,13 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

new_parameter_2 = MessageTextInput(
name="z_02_api_key_name",
display_name="API Key name",
display_name="API Key Name",
info="The name of the embeddings provider API key stored on Astra. "
"If set, it will override the 'ProviderKey' in the authentication parameters.",
).to_dict()

new_parameter_3 = SecretStrInput(
load_from_db=False,
name="z_03_provider_api_key",
display_name="Provider API Key",
info="An alternative to the Astra Authentication that passes an API key for the provider "
Expand All @@ -319,15 +329,14 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

new_parameter_4 = DictInput(
name="z_04_authentication",
display_name="Authentication parameters",
display_name="Authentication Parameters",
is_list=True,
).to_dict()

self.insert_in_dict(
build_config,
"provider",
"model",
{
"z_00_model_name": new_parameter_0,
"z_01_model_parameters": new_parameter_1,
"z_02_api_key_name": new_parameter_2,
"z_03_provider_api_key": new_parameter_3,
Expand All @@ -339,8 +348,8 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:

def build_vectorize_options(self, **kwargs):
for attribute in [
"provider",
"z_00_model_name",
"embedding_provider",
"model",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
Expand All @@ -350,8 +359,10 @@ def build_vectorize_options(self, **kwargs):
setattr(self, attribute, None)

# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
model_name = self.z_00_model_name or kwargs.get("z_00_model_name")
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.embedding_provider, [None])[0] or kwargs.get(
"embedding_provider"
)
model_name = self.model or kwargs.get("model")
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}
parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {})

Expand Down Expand Up @@ -414,6 +425,7 @@ def build_vector_store(self, vectorize_options=None):
),
"collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
}

try:
vector_store = AstraDBVectorStore(
collection_name=self.collection_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_astra_vectorize():
store = None
try:
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"}
options_comp = {"embedding_provider": "nvidia", "model": "NV-Embed-QA"}

store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION,
Expand Down Expand Up @@ -150,8 +150,8 @@ def test_astra_vectorize_with_provider_api_key():
}

options_comp = {
"provider": "openai",
"z_00_model_name": "text-embedding-3-small",
"embedding_provider": "openai",
"model": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_03_provider_api_key": "openai",
"z_04_authentication": {},
Expand Down Expand Up @@ -206,8 +206,8 @@ def test_astra_vectorize_passes_authentication():
"authentication": {"providerKey": "openai"},
}
options_comp = {
"provider": "openai",
"z_00_model_name": "text-embedding-3-small",
"embedding_provider": "openai",
"model": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_04_authentication": {"providerKey": "openai"},
}
Expand Down
Loading