
GUI SUPPORT #21

Merged · 13 commits · May 14, 2023
91 changes: 91 additions & 0 deletions gui.py
@@ -0,0 +1,91 @@
import dotenv
import streamlit as st
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header
from streamlit_extras.add_vertical_space import add_vertical_space
import startLLM
import os

dotenv_file = dotenv.find_dotenv(".env")
dotenv.load_dotenv()
llama_embeddings_model = os.environ.get("LLAMA_EMBEDDINGS_MODEL")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
model_temp = float(os.environ.get('MODEL_TEMP'))
model_stop = os.environ.get('MODEL_STOP')
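# The lookups above assume a .env file next to gui.py defining each of these keys.
# A purely illustrative example (paths and values are placeholders, not project defaults):
#   LLAMA_EMBEDDINGS_MODEL=./models/ggml-model-q4_0.bin
#   PERSIST_DIRECTORY=./db
#   MODEL_TYPE=LlamaCpp
#   MODEL_PATH=./models/ggml-vic-7b-q5_1.bin
#   MODEL_N_CTX=1024
#   MODEL_TEMP=0.8
#   MODEL_STOP=###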

# Initialization
if "input" not in st.session_state:
st.session_state.input = ""
st.session_state.running = False

st.set_page_config(page_title="CASALIOY")

# Sidebar contents
with st.sidebar:
    st.title('CASALIOY')
    st.markdown('''
## About
This app is an LLM-powered chatbot built using:
- [Streamlit](https://streamlit.io/)
- [su77ungr/CASALIOY](https://github.com/alxspiker/CASALIOY) LLM Toolkit

💡 Note: No API key required!
Refreshing the page will restart gui.py with a fresh chat history.
CASALIOY will not remember previous questions as of yet.

GUI does not support live response yet, so you have to wait for the tokens to process.
''')
    add_vertical_space(5)
    st.write('Made with ❤️ by [su77ungr/CASALIOY](https://github.com/alxspiker/CASALIOY)')

if 'generated' not in st.session_state:
    st.session_state['generated'] = ["I can help you answer questions about the documents you have ingested into the vector store."]

if 'past' not in st.session_state:
    st.session_state['past'] = ['Hi, what can you help me with!']
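# 'generated' holds the model's replies and 'past' the user's prompts;
# generate_response() below replays both lists to rebuild the chat on every rerun.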

colored_header(label='', description='', color_name='blue-30')
response_container = st.container()



def generate_response(input=""):
    with response_container:
        col1, col2, col3 = st.columns(3)
        with col1:
            if st.number_input('Temperature', key="temp_input", value=float(model_temp), step=float(0.05), min_value=float(0), max_value=float(1)):
                os.environ["MODEL_TEMP"] = str(st.session_state.temp_input)
                dotenv.set_key(dotenv_file, "MODEL_TEMP", os.environ["MODEL_TEMP"])
        with col2:
            if st.number_input('Context', key="ctx_input", value=int(model_n_ctx), step=int(512), min_value=int(512), max_value=int(9000)):
                os.environ["MODEL_N_CTX"] = str(st.session_state.ctx_input)
                dotenv.set_key(dotenv_file, "MODEL_N_CTX", os.environ["MODEL_N_CTX"])
        with col3:
            if st.text_input('Stops', key="stops_input", value=str(model_stop)):
                os.environ["MODEL_STOP"] = str(st.session_state.stops_input)
                dotenv.set_key(dotenv_file, "MODEL_STOP", os.environ["MODEL_STOP"])
        #with st.form("my_form", clear_on_submit=True):
        if st.session_state['generated']:
            for i in range(len(st.session_state['generated'])):
                message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
                message(st.session_state["generated"][i], key=str(i))
        if input.strip() != "":
            st.session_state.running = True
            st.session_state.past.append(st.session_state.input)
        if st.session_state.running:
            message(st.session_state.input, is_user=True)
            message("Loading response. Please wait for me to finish before refreshing the page...", key="rmessage")
            #startLLM.qdrant = None #Not sure why this fixes db error
            response = startLLM.main(st.session_state.input, True)
            st.session_state.generated.append(response)
            message(response)
            st.session_state.running = False
    st.text_input("You: ", "", key="input", disabled=st.session_state.running)


with st.form("my_form", clear_on_submit=True):
    st.form_submit_button('SUBMIT', on_click=generate_response(st.session_state.input), disabled=st.session_state.running)
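# Note: generate_response(st.session_state.input) is evaluated right here, at render
# time on every Streamlit rerun, and on_click receives its return value (None); the
# submit button therefore mainly triggers the rerun that feeds the new input through.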

5 changes: 4 additions & 1 deletion requirements.txt
@@ -4,5 +4,8 @@ qdrant-client==1.1.7
 llama-cpp-python==0.1.49
 pdfminer.six==20221105
 python-dotenv==1.0.0
+streamlit==1.22.0
+streamlit-chat==0.0.2.2
+streamlit-extras==0.2.7
 pandoc==2.3
-unstructured==0.6.6
+unstructured==0.6.6
53 changes: 32 additions & 21 deletions startLLM.py
@@ -11,11 +11,13 @@
 persist_directory = os.environ.get('PERSIST_DIRECTORY')
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
-model_n_ctx = os.environ.get('MODEL_N_CTX')
-model_temp = os.environ.get('MODEL_TEMP')
+model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
+model_temp = float(os.environ.get('MODEL_TEMP'))
 model_stop = os.environ.get('MODEL_STOP').split(",")
 
-def main():
+qa_system=None
+
+def initialize_qa_system():
     # Load stored vectorstore
     llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
     # Load ggml-formatted model
@@ -41,27 +43,36 @@ def main():
         case _default:
             print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
+    return qa
+
+def main(prompt="", gui=False):
+    global qa_system
+    if qa_system is None:
+        qa_system = initialize_qa_system()
     # Interactive questions and answers
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-
-        # Get the answer from the chain
-        res = qa(query)
-        answer, docs = res['result'], res['source_documents']
+    if (prompt.strip() != "" and gui) or gui==False:
+        while True:
+            query = prompt if gui else input("\nEnter a query: ")
+            if query == "exit":
+                break
+
+            # Get the answer from the chain
+            res = qa_system(query)
+            answer, docs = res['result'], res['source_documents']
 
-        # Print the result
-        print("\n\n> Question:")
-        print(query)
-        print("\n> Answer:")
-        print(answer)
-
-        # Print the relevant sources used for the answer
-        for document in docs:
-            print("\n> " + document.metadata["source"] + ":")
-            print(document.page_content)
+            # Print the result
+            print("\n\n> Question:")
+            print(query)
+            print("\n> Answer:")
+            print(answer)
+
+            # Print the relevant sources used for the answer
+            for document in docs:
+                print("\n> " + document.metadata["source"] + ":")
+                print(document.page_content)
 
+            if gui:
+                return answer
 
 if __name__ == "__main__":
     main()
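With this change startLLM.main() serves both the terminal and the GUI. A minimal usage sketch, assuming the modules above are importable and the .env and model files are in place (illustrative only, not part of the diff):

import startLLM

# CLI-style: calling main() with no arguments keeps the old behaviour, an
# interactive input() loop that runs until the user types "exit".
# startLLM.main()

# GUI-style (what gui.py does): pass the prompt and gui=True; the qa_system is
# built once and cached at module level, and main() returns the answer string.
answer = startLLM.main("What do my ingested documents say about pricing?", gui=True)
print(answer)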