-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add model_utils and model_constants intorducing `PROVIDER_NAMES…
…` and `MODEL_INFO` variables for dynamic updation of model data (#4341) feat: Add model_utils and model_constants - Enhanced the initialization logic in model_constants to handle delayed imports and circular dependencies. - Improved type hinting for better code clarity and maintainability. Details: - `get_model_info`: Retrieves comprehensive information about all available models, which is used to populate the `MODEL_INFO` dictionary. - `MODEL_INFO`: A dictionary where each key is a model identifier, and the value is a dictionary containing details about the model, such as its `display_name` and configuration options. - `PROVIDER_NAMES`: A list derived from `MODEL_INFO` that holds the names of model providers, providing a quick reference to all available model providers.
- Loading branch information
1 parent
10a63ff
commit 98ee051
Showing
6 changed files
with
138 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
class ModelConstants: | ||
"""Class to hold model-related constants. To solve circular import issue.""" | ||
|
||
PROVIDER_NAMES: list[str] = [] | ||
MODEL_INFO: dict[str, dict[str, str | list]] = {} # Adjusted type hint | ||
|
||
@staticmethod | ||
def initialize(): | ||
from langflow.base.models.model_utils import get_model_info # Delayed import | ||
|
||
model_info = get_model_info() | ||
ModelConstants.MODEL_INFO = model_info | ||
ModelConstants.PROVIDER_NAMES = [ | ||
str(model.get("display_name")) | ||
for model in model_info.values() | ||
if isinstance(model.get("display_name"), str) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import importlib | ||
|
||
from langflow.base.models.model import LCModelComponent | ||
from langflow.inputs.inputs import InputTypes | ||
|
||
|
||
def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]: | ||
"""Get inputs for all model components.""" | ||
model_inputs = {} | ||
models_module = importlib.import_module("langflow.components.models") | ||
model_component_names = getattr(models_module, "__all__", []) | ||
|
||
for name in model_component_names: | ||
if name in ("base", "DynamicLLMComponent"): # Skip the base module | ||
continue | ||
|
||
component_class = getattr(models_module, name) | ||
if issubclass(component_class, LCModelComponent): | ||
component = component_class() | ||
base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs} | ||
input_fields_list = [ | ||
input_field for input_field in component.inputs if input_field.name not in base_input_names | ||
] | ||
component_display_name = component.display_name | ||
model_inputs[name] = { | ||
"display_name": component_display_name, | ||
"inputs": input_fields_list, | ||
"icon": component.icon, | ||
} | ||
|
||
return model_inputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
25 changes: 25 additions & 0 deletions
25
src/backend/tests/unit/base/models/test_model_constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from src.backend.base.langflow.base.models.model_constants import ModelConstants | ||
|
||
|
||
def test_provider_names(): | ||
# Initialize the ModelConstants | ||
ModelConstants.initialize() | ||
|
||
# Expected provider names | ||
expected_provider_names = [ | ||
"AIML", | ||
"Amazon Bedrock", | ||
"Anthropic", | ||
"Azure OpenAI", | ||
"Ollama", | ||
"Vertex AI", | ||
"Cohere", | ||
"Google Generative AI", | ||
"HuggingFace", | ||
"OpenAI", | ||
"Perplexity", | ||
"Qianfan", | ||
] | ||
|
||
# Assert that the provider names match the expected list | ||
assert expected_provider_names == ModelConstants.PROVIDER_NAMES |