diff --git a/gui.py b/gui.py
new file mode 100644
index 0000000..fd3e2a0
--- /dev/null
+++ b/gui.py
@@ -0,0 +1,91 @@
+import dotenv
+import streamlit as st
+from streamlit_chat import message
+from streamlit_extras.colored_header import colored_header
+from streamlit_extras.add_vertical_space import add_vertical_space
+import startLLM
+import os
+
+dotenv_file = dotenv.find_dotenv(".env")
+dotenv.load_dotenv()
+llama_embeddings_model = os.environ.get("LLAMA_EMBEDDINGS_MODEL")
+persist_directory = os.environ.get('PERSIST_DIRECTORY')
+model_type = os.environ.get('MODEL_TYPE')
+model_path = os.environ.get('MODEL_PATH')
+model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
+model_temp = float(os.environ.get('MODEL_TEMP'))
+model_stop = os.environ.get('MODEL_STOP')
+
+# Initialization
+if "input" not in st.session_state:
+    st.session_state.input = ""
+    st.session_state.running = False
+
+st.set_page_config(page_title="CASALIOY")
+
+# Sidebar contents
+with st.sidebar:
+    st.title('CASALIOY')
+    st.markdown('''
+    ## About
+    This app is an LLM-powered chatbot built using:
+    - [Streamlit](https://streamlit.io/)
+    - [su77ungr/CASALIOY](https://github.com/alxspiker/CASALIOY) LLM Toolkit
+
+    💡 Note: No API key required!
+    Refreshing the page will restart gui.py with a fresh chat history.
+    CASALIOY will not remember previous questions as of yet.
+
+    GUI does not support live response yet, so you have to wait for the tokens to process.
+    ''')
+    add_vertical_space(5)
+    st.write('Made with ❤️ by [su77ungr/CASALIOY](https://github.com/alxspiker/CASALIOY)')
+
+if 'generated' not in st.session_state:
+    st.session_state['generated'] = ["I can help you answer questions about the documents you have ingested into the vector store."]
+
+if 'past' not in st.session_state:
+    st.session_state['past'] = ['Hi, what can you help me with!']
+
+colored_header(label='', description='', color_name='blue-30')
+response_container = st.container()
+
+
+
+def generate_response(input=""):
+    with response_container:
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            if st.number_input('Temperature', key="temp_input", value=float(model_temp), step=float(0.05), min_value=float(0), max_value=float(1)):
+                os.environ["MODEL_TEMP"] = str(st.session_state.temp_input)
+                dotenv.set_key(dotenv_file, "MODEL_TEMP", os.environ["MODEL_TEMP"])
+        with col2:
+            if st.number_input('Context', key="ctx_input", value=int(model_n_ctx), step=int(512), min_value=int(512), max_value=int(9000)):
+                os.environ["MODEL_N_CTX"] = str(st.session_state.ctx_input)
+                dotenv.set_key(dotenv_file, "MODEL_N_CTX", os.environ["MODEL_N_CTX"])
+        with col3:
+            if st.text_input('Stops', key="stops_input", value=str(model_stop)):
+                os.environ["MODEL_STOP"] = str(st.session_state.stops_input)
+                dotenv.set_key(dotenv_file, "MODEL_STOP", os.environ["MODEL_STOP"])
+        #with st.form("my_form", clear_on_submit=True):
+        if st.session_state['generated']:
+            for i in range(len(st.session_state['generated'])):
+                message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
+                message(st.session_state["generated"][i], key=str(i))
+        if input.strip() != "":
+            st.session_state.running=True
+            st.session_state.past.append(st.session_state.input)
+        if st.session_state.running:
+            message(st.session_state.input, is_user=True)
+            message("Loading response. Please wait for me to finish before refreshing the page...", key="rmessage")
+            #startLLM.qdrant = None #Not sure why this fixes db error
+            response = startLLM.main(st.session_state.input, True)
+            st.session_state.generated.append(response)
+            message(response)
+            st.session_state.running = False
+        st.text_input("You: ", "", key="input", disabled=st.session_state.running)
+
+with st.form("my_form", clear_on_submit=True):
+    st.form_submit_button('SUBMIT', on_click=generate_response(st.session_state.input), disabled=st.session_state.running)
+    
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index ea752c3..3734f00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,8 @@ qdrant-client==1.1.7
 llama-cpp-python==0.1.49
 pdfminer.six==20221105
 python-dotenv==1.0.0
+streamlit==1.22.0
+streamlit-chat==0.0.2.2
+streamlit-extras==0.2.7
 pandoc==2.3
-unstructured==0.6.6
+unstructured==0.6.6
diff --git a/startLLM.py b/startLLM.py
index c71dee5..503a1b2 100644
--- a/startLLM.py
+++ b/startLLM.py
@@ -11,11 +11,13 @@
 persist_directory = os.environ.get('PERSIST_DIRECTORY')
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
-model_n_ctx = os.environ.get('MODEL_N_CTX')
-model_temp = os.environ.get('MODEL_TEMP')
+model_n_ctx = int(os.environ.get('MODEL_N_CTX'))
+model_temp = float(os.environ.get('MODEL_TEMP'))
 model_stop = os.environ.get('MODEL_STOP').split(",")
 
-def main():
+qa_system=None
+
+def initialize_qa_system():
     # Load stored vectorstore
     llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
     # Load ggml-formatted model
@@ -41,27 +43,36 @@
         case _default:
             print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
+    return qa
 
+def main(prompt="", gui=False):
+    global qa_system
+    if qa_system is None:
+        qa_system = initialize_qa_system()
     # Interactive questions and answers
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-
-        # Get the answer from the chain
-        res = qa(query)
-        answer, docs = res['result'], res['source_documents']
+    if (prompt.strip() != "" and gui) or gui==False:
+        while True:
+            query = prompt if gui else input("\nEnter a query: ")
+            if query == "exit":
+                break
+
+            # Get the answer from the chain
+            res = qa_system(query)
+            answer, docs = res['result'], res['source_documents']
 
-        # Print the result
-        print("\n\n> Question:")
-        print(query)
-        print("\n> Answer:")
-        print(answer)
-
-        # Print the relevant sources used for the answer
-        for document in docs:
-            print("\n> " + document.metadata["source"] + ":")
-            print(document.page_content)
+            # Print the result
+            print("\n\n> Question:")
+            print(query)
+            print("\n> Answer:")
+            print(answer)
+
+            # Print the relevant sources used for the answer
+            for document in docs:
+                print("\n> " + document.metadata["source"] + ":")
+                print(document.page_content)
+
+            if gui:
+                return answer
 
 if __name__ == "__main__":
     main()
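
Usage sketch (not part of the patch): the diff above changes startLLM.main to accept a prompt and a gui flag and to return the answer when gui=True, which is how gui.py calls it. A minimal way to exercise the new entry point, assuming the .env variables referenced in the diff are set and the vector store has already been ingested, could look like this; the example prompt string is hypothetical.

# Sketch only: drives the refactored startLLM.main() from this diff.
# Assumes .env provides LLAMA_EMBEDDINGS_MODEL, PERSIST_DIRECTORY, MODEL_TYPE,
# MODEL_PATH, MODEL_N_CTX, MODEL_TEMP and MODEL_STOP, and that ingestion has
# already populated the vector store in PERSIST_DIRECTORY.
import startLLM

# gui=True answers a single prompt and returns the answer string
# (the first call also builds the QA chain via initialize_qa_system()).
answer = startLLM.main("What documents have been ingested?", gui=True)
print(answer)

# gui=False (the default) keeps the original interactive CLI loop:
#   startLLM.main()
# The Streamlit front end itself is started with:
#   streamlit run gui.py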