Skip to content

Commit

Permalink
feat: Add model_utils and model_constants intorducing `PROVIDER_NAMES…
Browse files Browse the repository at this point in the history
…` and `MODEL_INFO` variables for dynamic updation of model data (#4341)

feat: Add model_utils and model_constants

- Enhanced the initialization logic in model_constants to handle delayed imports and circular dependencies.
- Improved type hinting for better code clarity and maintainability.

Details:
- `get_model_info`: Retrieves comprehensive information about all available models, which is used to populate the `MODEL_INFO` dictionary.
- `MODEL_INFO`: A dictionary where each key is a model identifier, and the value is a dictionary containing details about the model, such as its `display_name` and configuration options.
- `PROVIDER_NAMES`: A list derived from `MODEL_INFO` that holds the names of model providers, providing a quick reference to all available model providers.
  • Loading branch information
edwinjosechittilappilly authored Nov 1, 2024
1 parent 10a63ff commit 98ee051
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 4 deletions.
61 changes: 61 additions & 0 deletions src/backend/base/langflow/base/models/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import json
import warnings
from abc import abstractmethod
Expand Down Expand Up @@ -206,3 +207,63 @@ def get_chat_result(
@abstractmethod
def build_model(self) -> LanguageModel: # type: ignore[type-var]
"""Implement this method to build the model."""

def get_llm(self, provider_name: str, model_info: dict[str, dict[str, str | list[InputTypes]]]) -> LanguageModel:
"""Get LLM model based on provider name and inputs.
Args:
provider_name: Name of the model provider (e.g., "OpenAI", "Azure OpenAI")
inputs: Dictionary of input parameters for the model
model_info: Dictionary of model information
Returns:
Built LLM model instance
"""
try:
if provider_name not in [model.get("display_name") for model in model_info.values()]:
msg = f"Unknown model provider: {provider_name}"
raise ValueError(msg)

# Find the component class name from MODEL_INFO in a single iteration
component_info, module_name = next(
((info, key) for key, info in model_info.items() if info.get("display_name") == provider_name),
(None, None),
)
if not component_info:
msg = f"Component information not found for {provider_name}"
raise ValueError(msg)
component_inputs = component_info.get("inputs", [])
# Get the component class from the models module
# Ensure component_inputs is a list of the expected types
if not isinstance(component_inputs, list):
component_inputs = []
models_module = importlib.import_module("langflow.components.models")
component_class = getattr(models_module, str(module_name))
component = component_class()

return self.build_llm_model_from_inputs(component, component_inputs)
except Exception as e:
msg = f"Error building {provider_name} language model"
raise ValueError(msg) from e

def build_llm_model_from_inputs(
self, component: Component, inputs: list[InputTypes], prefix: str = ""
) -> LanguageModel:
"""Build LLM model from component and inputs.
Args:
component: LLM component instance
inputs: Dictionary of input parameters for the model
prefix: Prefix for the input names
Returns:
Built LLM model instance
"""
# Ensure prefix is a string
prefix = prefix or ""
# Filter inputs to only include valid component input names
input_data = {
str(component_input.name): getattr(self, f"{prefix}{component_input.name}", None)
for component_input in inputs
}

return component.set(**input_data).build_model()
17 changes: 17 additions & 0 deletions src/backend/base/langflow/base/models/model_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class ModelConstants:
"""Class to hold model-related constants. To solve circular import issue."""

PROVIDER_NAMES: list[str] = []
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint

@staticmethod
def initialize():
from langflow.base.models.model_utils import get_model_info # Delayed import

model_info = get_model_info()
ModelConstants.MODEL_INFO = model_info
ModelConstants.PROVIDER_NAMES = [
str(model.get("display_name"))
for model in model_info.values()
if isinstance(model.get("display_name"), str)
]
31 changes: 31 additions & 0 deletions src/backend/base/langflow/base/models/model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import importlib

from langflow.base.models.model import LCModelComponent
from langflow.inputs.inputs import InputTypes


def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
"""Get inputs for all model components."""
model_inputs = {}
models_module = importlib.import_module("langflow.components.models")
model_component_names = getattr(models_module, "__all__", [])

for name in model_component_names:
if name in ("base", "DynamicLLMComponent"): # Skip the base module
continue

component_class = getattr(models_module, name)
if issubclass(component_class, LCModelComponent):
component = component_class()
base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs}
input_fields_list = [
input_field for input_field in component.inputs if input_field.name not in base_input_names
]
component_display_name = component.display_name
model_inputs[name] = {
"display_name": component_display_name,
"inputs": input_fields_list,
"icon": component.icon,
}

return model_inputs
8 changes: 4 additions & 4 deletions src/backend/base/langflow/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def build_vertex(self, vertex: Vertex) -> Vertex:

@celery_app.task(acks_late=True)
def process_graph_cached_task(
data_graph: dict[str, Any], # noqa: ARG001
inputs: dict | list[dict] | None = None, # noqa: ARG001
clear_cache=False, # noqa: ARG001, FBT002
session_id=None, # noqa: ARG001
data_graph: dict[str, Any],
inputs: dict | list[dict] | None = None,
clear_cache=False, # noqa: FBT002
session_id=None,
) -> dict[str, Any]:
msg = "This task is not implemented yet"
raise NotImplementedError(msg)
Empty file.
25 changes: 25 additions & 0 deletions src/backend/tests/unit/base/models/test_model_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from src.backend.base.langflow.base.models.model_constants import ModelConstants


def test_provider_names():
# Initialize the ModelConstants
ModelConstants.initialize()

# Expected provider names
expected_provider_names = [
"AIML",
"Amazon Bedrock",
"Anthropic",
"Azure OpenAI",
"Ollama",
"Vertex AI",
"Cohere",
"Google Generative AI",
"HuggingFace",
"OpenAI",
"Perplexity",
"Qianfan",
]

# Assert that the provider names match the expected list
assert expected_provider_names == ModelConstants.PROVIDER_NAMES

0 comments on commit 98ee051

Please # to comment.