# rag.py
import os

import streamlit as st
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
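
# Run the app with `streamlit run rag.py`. This assumes a local Ollama server is
# running (see the ChatOllama setup below) and that combined_text.txt is present
# in the working directory.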

# Initialize Streamlit app
st.title("Welcome! How can I help you today?")

# Chat history initialization
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Render the chat history so far
st.subheader("Chat History:")
with st.container():
    for qa in st.session_state["chat_history"]:
        st.markdown(f"**You:** {qa['question']}")
        st.markdown(f"**Assistant:** {qa['answer']}")
        st.markdown("---")

# Input text box
question = st.text_area("Enter your question:", placeholder="Type your question here...")

if st.button("Get Answer"):
    if question.strip():
        with st.spinner("Generating answer..."):
            file_paths = [
                "combined_text.txt",
            ]

            # Load the source text and split it into chunks
            docs = [TextLoader(file_path).load() for file_path in file_paths]
            docs_list = [item for sublist in docs for item in sublist]
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=250, chunk_overlap=0
            )
            doc_splits = text_splitter.split_documents(docs_list)
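            # Note: from_tiktoken_encoder sizes chunks in tokens, so each chunk
            # holds roughly 250 tokens of text with no overlap between chunks.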

            # Custom wrapper exposing sentence-transformers through LangChain's
            # Embeddings interface, used by FAISS below
            class HuggingFaceEmbeddings(Embeddings):
                def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
                    self.model = SentenceTransformer(model_name)

                def embed_documents(self, texts):
                    # .tolist() matches the Embeddings contract (lists of floats)
                    return self.model.encode(texts).tolist()

                def embed_query(self, text):
                    return self.model.encode([text])[0].tolist()

            embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
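            # (The langchain_huggingface package ships a ready-made
            # HuggingFaceEmbeddings class; the small wrapper above simply
            # avoids the extra dependency.)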

            # Reuse the cached vector store if present, otherwise build and save it
            index_filepath = "faiss_index"
            if os.path.exists(index_filepath):
                vectorstore = FAISS.load_local(
                    index_filepath, embedding_model, allow_dangerous_deserialization=True
                )
            else:
                vectorstore = FAISS.from_documents(doc_splits, embedding_model)
                vectorstore.save_local(index_filepath)
            retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
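            # The cached index is only rebuilt when the faiss_index directory is
            # missing, so delete that directory after editing combined_text.txt.
            # search_kwargs={"k": 4} returns the 4 most similar chunks per query.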

            # Prompt engineering: this template is the core of the model's behavior
            prompt = PromptTemplate(
                template="""You are an assistant for question-answering tasks.
Use the following documents to answer the question.
If you don't know the answer, just say that you don't know.
Answer in as many sentences as needed, but be accurate and detailed:

Question: {question}
Documents: {documents}
Answer:
""",
                input_variables=["question", "documents"],
            )

            # temperature=0 for deterministic responses; an Ollama server must be
            # running in the background
            llm = ChatOllama(
                model="llama3.1",  # base model served by Ollama
                temperature=0,
            )
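            # Assumes the model has already been pulled (e.g. `ollama pull llama3.1`)
            # and that Ollama is listening on its default port, 11434.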

            # RAG pipeline: prompt -> LLM -> plain string output
            rag_chain = prompt | llm | StrOutputParser()

            class RAGApplication:
                def __init__(self, retriever, rag_chain):
                    self.retriever = retriever
                    self.rag_chain = rag_chain

                def run(self, question):
                    # Retrieve the most relevant chunks and pass them, together
                    # with the question, through the chain
                    documents = self.retriever.invoke(question)
                    doc_texts = "\n".join([doc.page_content for doc in documents])
                    answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
                    return answer

            # Actually run the RAG pipeline
            rag_application = RAGApplication(retriever, rag_chain)
            answer = rag_application.run(question)

            # Save to chat history
            st.session_state["chat_history"].append({"question": question, "answer": answer})

            # Display the answer
            st.subheader("Answer:")
            st.write(answer)
    else:
        st.error("Please enter a question before clicking the button.")
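
# Note: the SentenceTransformer model is re-instantiated and the FAISS index
# reloaded on every button click. Wrapping their construction in a function
# decorated with @st.cache_resource would keep them in memory across reruns.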