# main.py - SHAi (Sustainable Hydroponic AI) Streamlit app
import os
import pickle

import numpy as np  # used by the (currently disabled) TruLens feedback block
import streamlit as st
import google.generativeai as palm
from llama_index import VectorStoreIndex, ServiceContext, StorageContext
from llama_index.vector_stores import MilvusVectorStore
from llama_index.embeddings import GooglePaLMEmbedding
from llama_index.llms import PaLM
from llama_index.memory import ChatMemoryBuffer
from trulens_eval import Feedback, Tru, TruLlama
from trulens_eval.feedback import Groundedness
from trulens_eval.feedback.provider.openai import OpenAI as OAI
tru = Tru()
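# Tru() initializes the local TruLens workspace and database that back the
# evaluation block further below (currently commented out).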
# 1. Set up the name of the collection to be created.
COLLECTION_NAME = 'hydroponics_knowledge_base'
# 2. Set up the dimension of the embeddings.
#    NOTE: this must match the embedding model in use; PaLM's
#    embedding-gecko-001 returns 768-dimensional vectors, so verify that the
#    collection was created with the dimension set here.
DIMENSION = 1536
# 3. Set the inference parameters (defined for tuning; not referenced yet).
BATCH_SIZE = 128
TOP_K = 3
# 4. Set up the connection parameters for your Zilliz Cloud cluster.
URI = st.secrets['CLUSTER_ENDPOINT']
TOKEN = st.secrets['API_TOKEN']
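# Expected .streamlit/secrets.toml layout (key names taken from the
# st.secrets lookups in this file; values are placeholders):
#
#   CLUSTER_ENDPOINT = "https://<cluster-id>.<region>.zillizcloud.com"
#   API_TOKEN = "<zilliz-api-token>"
#   PALM_API_KEY = "<google-palm-api-key>"
#   OPEN_API_KEY = "<openai-api-key>"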
# PaLM API
palm_api_key = st.secrets['PALM_API_KEY']
palm.configure(api_key=palm_api_key)
models = [
    m
    for m in palm.list_models()
    if "generateText" in m.supported_generation_methods
]
model = models[0].name
print(model)
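# NOTE: `model` is only printed for visibility; the PaLM LLM below is
# constructed without a model name and therefore uses its default text model.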
llm = PaLM(api_key=palm_api_key)
# OpenAI API key, read from .streamlit/secrets.toml (stored under "OPEN_API_KEY")
os.environ["OPENAI_API_KEY"] = st.secrets["OPEN_API_KEY"]
# Creating memory
memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
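# The buffer retains roughly the most recent 1500 tokens of conversation,
# dropping older turns so prompts stay within the model's context window.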
# Query engine and TruLens wrapper (superseded by the chat engine below)
# query_engine = index.as_query_engine()
# l = TruLlama(query_engine)
st.title("SHAi")
st.text("Sustainable Hydroponic AI")
# client = TruLlama(query_engine)
if "openai_model" not in st.session_state:
# st.session_state["openai_model"] = "gpt-3.5-turbo"
pass
if "messages" not in st.session_state:
st.session_state.messages = [
{"role": "assistant", "content": "Ask me a question about Hydroponics & I will try to answer it in a sustainable way possible!"}
]
@st.cache_resource(show_spinner=False)
def load_data():
    with st.spinner(text="Loading and indexing the vectors from Zilliz – hang tight! This should take 30-50 seconds."):
        # Load the pre-parsed documents from the local pickle file, if present
        docs = []
        if os.path.exists("docs.pkl"):
            with open("docs.pkl", "rb") as f:
                docs = pickle.load(f)
        # Connect to the Zilliz Cloud collection that backs the index
        vector_store = MilvusVectorStore(
            uri=URI,
            token=TOKEN,
            collection_name=COLLECTION_NAME,
            similarity_metric="L2",
            dim=DIMENSION,
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        llm = PaLM(api_key=palm_api_key)
        # Service context: PaLM LLM + PaLM embeddings
        embed_model = GooglePaLMEmbedding("models/embedding-gecko-001", api_key=palm_api_key)
        service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
        # Build the index on top of the Zilliz-backed vector store so the
        # documents are pushed to Zilliz Cloud
        index = VectorStoreIndex.from_documents(
            documents=docs,
            storage_context=storage_context,
            service_context=service_context,
            show_progress=True,
        )
        return index
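# A minimal sketch of how "docs.pkl" could be produced offline (assumption:
# the knowledge base is a local folder of files readable by
# SimpleDirectoryReader; this is not part of the app's runtime path):
#
#   from llama_index import SimpleDirectoryReader
#   docs = SimpleDirectoryReader("./knowledge_base").load_data()
#   with open("docs.pkl", "wb") as f:
#       pickle.dump(docs, f)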
index = load_data()
# chat_engine = index.as_chat_engine(chat_mode="condense_question", verbose=True)
chat_engine = index.as_chat_engine(
    chat_mode="context", memory=memory, verbose=True
)
# chat_engine = index.as_query_engine()
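# In "context" mode the engine retrieves relevant chunks from the index for
# each user turn and places them in the system prompt, while `memory` carries
# the running conversation.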
# Turning off TruLens in the production app to avoid a production error.
# # Initialize provider class
# openai = OAI()
# grounded = Groundedness(groundedness_provider=OAI())
# # Define a groundedness feedback function
# f_groundedness = Feedback(grounded.groundedness_measure_with_cot_reasons).on(
#     TruLlama.select_source_nodes().node.text.collect()
# ).on_output(
# ).aggregate(grounded.grounded_statements_aggregator)
# # Question/answer relevance between the overall question and answer
# f_qa_relevance = Feedback(openai.relevance).on_input_output()
# # Question/statement relevance between the question and each context chunk
# f_qs_relevance = Feedback(openai.qs_relevance).on_input().on(
#     TruLlama.select_source_nodes().node.text
# ).aggregate(np.mean)
# tru_query_engine_recorder = TruLlama(
#     chat_engine,
#     app_id='SHAi_App',
#     feedbacks=[f_groundedness, f_qa_relevance, f_qs_relevance],
# )
# tru.run_dashboard()
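# To re-enable evaluation, record chat calls through the wrapper, e.g.:
#     with tru_query_engine_recorder as recording:
#         chat_engine.chat(prompt)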
# Prompt for user input and save it to the chat history
if prompt := st.chat_input("Your question"):
    st.session_state.messages.append({"role": "user", "content": prompt})

# Display the prior chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])
# If the last message is not from the assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        try:
            with st.spinner("Thinking..."):
                response = chat_engine.chat(prompt)
                st.write(response.response)
                message = {"role": "assistant", "content": response.response}
                st.session_state.messages.append(message)
        except Exception as e:
            # st.error(f"Error generating response: {e}")
            response = "I regret that I cannot provide information on this matter. Perhaps I could assist you with a different question related to hydroponics?"
            st.write(response)
            message = {"role": "assistant", "content": response}
            st.session_state.messages.append(message)
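# Launch locally with: streamlit run main.py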