# query_data.py
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from dotenv import load_dotenv
import os

# Load environment variables (e.g. HUGGINGFACEHUB_API_TOKEN).
load_dotenv()

CHROMA_PATH = "chroma"

def setup_qa_chain():
    # Set up the embedding function.
    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # Load the persisted Chroma database.
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # The retriever returns the 2 results most relevant to the query.
    retriever = db.as_retriever(search_kwargs={"k": 2})

    # Set up the Hugging Face model.
    llm = HuggingFaceHub(
        repo_id="google/flan-t5-base",
        # temperature controls the randomness of the output.
        model_kwargs={"temperature": 0.5, "max_length": 512},
        huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
    )

    prompt_template = """Use the following retrieved context information and your general knowledge to answer the question. If the question cannot be answered solely based on the context information, use your knowledge to reason and explain.
If the context information is irrelevant or insufficient to answer the question, primarily rely on your knowledge to provide a comprehensive and accurate answer.

Context information:
{context}

Question: {question}
Answer:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # "stuff" packs all retrieved chunks into a single prompt
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT},
    )
    return qa_chain
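
# A minimal usage sketch, under the assumption that this module is run as a
# script: RetrievalQA chains take a dict with a "query" key and, because
# return_source_documents=True above, return both "result" and
# "source_documents". The example question is purely illustrative.
if __name__ == "__main__":
    qa_chain = setup_qa_chain()
    response = qa_chain.invoke(
        {"query": "What topics are covered in the indexed documents?"}
    )
    print("Answer:", response["result"])
    # Show which chunks the answer was grounded in.
    for doc in response["source_documents"]:
        print("Source:", doc.metadata)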