-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: update Self Query Retriever Component (#3653)
* 🔧 (pyproject.toml): add lark dependency to support parsing and processing of grammars in the project ♻️ (SelfQueryRetriever.py): refactor input types in metadata fields to improve clarity and maintainability * 📝 (SelfQueryRetriever.py): Update class name and imports for consistency and clarity 📝 (SelfQueryRetriever.py): Refactor input and output definitions for better readability and maintainability 📝 (SelfQueryRetriever.py): Refactor method signatures and variable names for improved code organization and understanding * [autofix.ci] apply automated fixes * ♻️ (SelfQueryRetriever.py): Remove unused import 'VectorStore' to clean up the code and improve maintainability. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
076f4f0
commit 706d559
Showing
3 changed files
with
75 additions
and
51 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 56 additions & 50 deletions
106
src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,76 @@ | ||
# from langflow.field_typing import Data | ||
from typing import List | ||
|
||
from langchain.chains.query_constructor.base import AttributeInfo | ||
from langchain.retrievers.self_query.base import SelfQueryRetriever | ||
from langchain_core.vectorstores import VectorStore | ||
|
||
from langflow.custom import CustomComponent | ||
from langflow.field_typing import LanguageModel, Text | ||
from langflow.custom import Component | ||
from langflow.inputs import HandleInput, MessageTextInput | ||
from langflow.io import Output | ||
from langflow.schema import Data | ||
from langflow.schema.message import Message | ||
|
||
|
||
class SelfQueryRetrieverComponent(CustomComponent): | ||
display_name: str = "Self Query Retriever" | ||
description: str = "Retriever that uses a vector store and an LLM to generate the vector store queries." | ||
class SelfQueryRetrieverComponent(Component): | ||
display_name = "Self Query Retriever" | ||
description = "Retriever that uses a vector store and an LLM to generate the vector store queries." | ||
name = "SelfQueryRetriever" | ||
icon = "LangChain" | ||
|
||
def build_config(self): | ||
return { | ||
"query": { | ||
"display_name": "Query", | ||
"input_types": ["Message", "Text"], | ||
"info": "Query to be passed as input.", | ||
}, | ||
"vectorstore": { | ||
"display_name": "Vector Store", | ||
"info": "Vector Store to be passed as input.", | ||
}, | ||
"attribute_infos": { | ||
"display_name": "Metadata Field Info", | ||
"info": "Metadata Field Info to be passed as input.", | ||
}, | ||
"document_content_description": { | ||
"display_name": "Document Content Description", | ||
"info": "Document Content Description to be passed as input.", | ||
}, | ||
"llm": { | ||
"display_name": "LLM", | ||
"info": "LLM to be passed as input.", | ||
}, | ||
} | ||
inputs = [ | ||
HandleInput( | ||
name="query", | ||
display_name="Query", | ||
info="Query to be passed as input.", | ||
input_types=["Message", "Text"], | ||
), | ||
HandleInput( | ||
name="vectorstore", | ||
display_name="Vector Store", | ||
info="Vector Store to be passed as input.", | ||
input_types=["VectorStore"], | ||
), | ||
HandleInput( | ||
name="attribute_infos", | ||
display_name="Metadata Field Info", | ||
info="Metadata Field Info to be passed as input.", | ||
input_types=["Data"], | ||
is_list=True, | ||
), | ||
MessageTextInput( | ||
name="document_content_description", | ||
display_name="Document Content Description", | ||
info="Document Content Description to be passed as input.", | ||
), | ||
HandleInput( | ||
name="llm", | ||
display_name="LLM", | ||
info="LLM to be passed as input.", | ||
input_types=["LanguageModel"], | ||
), | ||
] | ||
|
||
outputs = [ | ||
Output(display_name="Retrieved Documents", name="documents", method="retrieve_documents"), | ||
] | ||
|
||
def build( | ||
self, | ||
query: Message, | ||
vectorstore: VectorStore, | ||
attribute_infos: list[Data], | ||
document_content_description: Text, | ||
llm: LanguageModel, | ||
) -> Data: | ||
metadata_field_infos = [AttributeInfo(**value.data) for value in attribute_infos] | ||
def retrieve_documents(self) -> List[Data]: | ||
metadata_field_infos = [AttributeInfo(**value.data) for value in self.attribute_infos] | ||
self_query_retriever = SelfQueryRetriever.from_llm( | ||
llm=llm, | ||
vectorstore=vectorstore, | ||
document_contents=document_content_description, | ||
llm=self.llm, | ||
vectorstore=self.vectorstore, | ||
document_contents=self.document_content_description, | ||
metadata_field_info=metadata_field_infos, | ||
enable_limit=True, | ||
) | ||
|
||
if isinstance(query, Message): | ||
input_text = query.text | ||
elif isinstance(query, str): | ||
input_text = query | ||
if isinstance(self.query, Message): | ||
input_text = self.query.text | ||
elif isinstance(self.query, str): | ||
input_text = self.query | ||
else: | ||
raise ValueError(f"Query type {type(self.query)} not supported.") | ||
|
||
if not isinstance(query, str): | ||
raise ValueError(f"Query type {type(query)} not supported.") | ||
documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()}) | ||
data = [Data.from_document(document) for document in documents] | ||
self.status = data | ||
return data # type: ignore | ||
return data |