diff --git a/README.md b/README.md index 5ba1c427..569e7303 100644 --- a/README.md +++ b/README.md @@ -255,4 +255,4 @@ Please cite our paper if you use this code or part of it in your work: year={2024}, booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)} } -``` +``` \ No newline at end of file diff --git a/frontend/demo_light/.streamlit/config.toml b/frontend/demo_light/.streamlit/config.toml new file mode 100644 index 00000000..e1c8593f --- /dev/null +++ b/frontend/demo_light/.streamlit/config.toml @@ -0,0 +1,10 @@ +[client] +showErrorDetails = false +toolbarMode = "minimal" + +[theme] +primaryColor = "#F63366" +backgroundColor = "#FFFFFF" +secondaryBackgroundColor = "#F0F2F6" +textColor = "#262730" +font = "sans serif" \ No newline at end of file diff --git a/frontend/demo_light/README.md b/frontend/demo_light/README.md new file mode 100644 index 00000000..6a41c0e0 --- /dev/null +++ b/frontend/demo_light/README.md @@ -0,0 +1,33 @@ +# STORM Minimal User Interface + +This is a minimal user interface for `STORMWikiRunner` which includes the following features: +1. Allowing user to create a new article through the "Create New Article" page. +2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article. +3. Displaying the written article and references side by side. +4. Allowing user to view previously created articles through the "My Articles" page. + +
+<p align="center">
+  <img src="assets/create_article.jpg" style="width: 80%; height: auto;">
+</p>
+
+<p align="center">
+  <img src="assets/article_display.jpg" style="width: 80%; height: auto;">
+</p>
+
+ +## Setup +1. Besides the required packages for `STORMWikiRunner`, you need to install additional packages: + ```bash + pip install -r requirements.txt + ``` +2. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`. +3. Run the following command to start the user interface: + ```bash + streamlit run storm.py + ``` + The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs. + +## Customization + +You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file. + +The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need. diff --git a/frontend/demo_light/assets/article_display.jpg b/frontend/demo_light/assets/article_display.jpg new file mode 100644 index 00000000..8b0236c3 Binary files /dev/null and b/frontend/demo_light/assets/article_display.jpg differ diff --git a/frontend/demo_light/assets/create_article.jpg b/frontend/demo_light/assets/create_article.jpg new file mode 100644 index 00000000..35b44f90 Binary files /dev/null and b/frontend/demo_light/assets/create_article.jpg differ diff --git a/frontend/demo_light/assets/void.jpg b/frontend/demo_light/assets/void.jpg new file mode 100644 index 00000000..1cda9a53 Binary files /dev/null and b/frontend/demo_light/assets/void.jpg differ diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py new file mode 100644 index 00000000..d940aa09 --- /dev/null +++ b/frontend/demo_light/demo_util.py @@ -0,0 +1,572 @@ +import base64 +import datetime +import io +import json +import os +import re +from typing import Optional + +import markdown +import pdfkit +import pytz +import streamlit as st +from lm import OpenAIModel +from rm import YouRM +from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from storm_wiki.modules.callback import BaseCallbackHandler + +from stoc import stoc + + +class DemoFileIOHelper(): + @staticmethod + def read_structure_to_dict(articles_root_path): + """ + Reads the directory structure of articles stored in the given root path and + returns a nested dictionary. The outer dictionary has article names as keys, + and each value is another dictionary mapping file names to their absolute paths. + + Args: + articles_root_path (str): The root directory path containing article subdirectories. + + Returns: + dict: A dictionary where each key is an article name, and each value is a dictionary + of file names and their absolute paths within that article's directory. + """ + articles_dict = {} + for topic_name in os.listdir(articles_root_path): + topic_path = os.path.join(articles_root_path, topic_name) + if os.path.isdir(topic_path): + # Initialize or update the dictionary for the topic + articles_dict[topic_name] = {} + # Iterate over all files within a topic directory + for file_name in os.listdir(topic_path): + file_path = os.path.join(topic_path, file_name) + articles_dict[topic_name][file_name] = os.path.abspath(file_path) + return articles_dict + + @staticmethod + def read_txt_file(file_path): + """ + Reads the contents of a text file and returns it as a string. + + Args: + file_path (str): The path to the text file to be read. 
+ + Returns: + str: The content of the file as a single string. + """ + with open(file_path) as f: + return f.read() + + @staticmethod + def read_json_file(file_path): + """ + Reads a JSON file and returns its content as a Python dictionary or list, + depending on the JSON structure. + + Args: + file_path (str): The path to the JSON file to be read. + + Returns: + dict or list: The content of the JSON file. The type depends on the + structure of the JSON file (object or array at the root). + """ + with open(file_path) as f: + return json.load(f) + + @staticmethod + def read_image_as_base64(image_path): + """ + Reads an image file and returns its content encoded as a base64 string, + suitable for embedding in HTML or transferring over networks where binary + data cannot be easily sent. + + Args: + image_path (str): The path to the image file to be encoded. + + Returns: + str: The base64 encoded string of the image, prefixed with the necessary + data URI scheme for images. + """ + with open(image_path, "rb") as f: + data = f.read() + encoded = base64.b64encode(data) + data = "data:image/png;base64," + encoded.decode("utf-8") + return data + + @staticmethod + def set_file_modification_time(file_path, modification_time_string): + """ + Sets the modification time of a file based on a given time string in the California time zone. + + Args: + file_path (str): The path to the file. + modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format. + """ + california_tz = pytz.timezone('America/Los_Angeles') + modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S') + modification_time = california_tz.localize(modification_time) + modification_time_utc = modification_time.astimezone(datetime.timezone.utc) + modification_timestamp = modification_time_utc.timestamp() + os.utime(file_path, (modification_timestamp, modification_timestamp)) + + @staticmethod + def get_latest_modification_time(path): + """ + Returns the latest modification time of all files in a directory in the California time zone as a string. + + Args: + directory_path (str): The path to the directory. + + Returns: + str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format. + """ + california_tz = pytz.timezone('America/Los_Angeles') + latest_mod_time = None + + file_paths = [] + if os.path.isdir(path): + for root, dirs, files in os.walk(path): + for file in files: + file_paths.append(os.path.join(root, file)) + else: + file_paths = [path] + + for file_path in file_paths: + modification_timestamp = os.path.getmtime(file_path) + modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp) + modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc) + modification_time_california = modification_time_utc.astimezone(california_tz) + + if latest_mod_time is None or modification_time_california > latest_mod_time: + latest_mod_time = modification_time_california + + if latest_mod_time is not None: + return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S') + else: + return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + @staticmethod + def assemble_article_data(article_file_path_dict): + """ + Constructs a dictionary containing the content and metadata of an article + based on the available files in the article's directory. This includes the + main article text, citations from a JSON file, and a conversation log if + available. 
The function prioritizes a polished version of the article if + both a raw and polished version exist. + + Args: + article_file_paths (dict): A dictionary where keys are file names relevant + to the article (e.g., the article text, citations + in JSON format, conversation logs) and values + are their corresponding file paths. + + Returns: + dict or None: A dictionary containing the parsed content of the article, + citations, and conversation log if available. Returns None + if neither the raw nor polished article text exists in the + provided file paths. + """ + if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict: + full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt" + article_data = {"article": DemoTextProcessingHelper.parse( + DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))} + if "url_to_info.json" in article_file_path_dict: + article_data["citations"] = _construct_citation_dict_from_search_result( + DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"])) + if "conversation_log.json" in article_file_path_dict: + article_data["conversation_log"] = DemoFileIOHelper.read_json_file( + article_file_path_dict["conversation_log.json"]) + return article_data + return None + + +class DemoTextProcessingHelper(): + + @staticmethod + def remove_citations(sent): + return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "") + + @staticmethod + def parse_conversation_history(json_data): + """ + Given conversation log data, return list of parsed data of following format + (persona_name, persona_description, list of dialogue turn) + """ + parsed_data = [] + for persona_conversation_data in json_data: + if ': ' in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split(": ", 1) + elif '- ' in persona_conversation_data["perspective"]: + name, description = persona_conversation_data["perspective"].split("- ", 1) + else: + name, description = "", persona_conversation_data["perspective"] + cur_conversation = [] + for dialogue_turn in persona_conversation_data["dlg_turns"]: + cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]}) + cur_conversation.append( + {"role": "assistant", + "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])}) + parsed_data.append((name, description, cur_conversation)) + return parsed_data + + @staticmethod + def parse(text): + regex = re.compile(r']:\s+"(.*?)"\s+http') + text = regex.sub(']: http', text) + return text + + @staticmethod + def add_markdown_indentation(input_string): + lines = input_string.split('\n') + processed_lines = [""] + for line in lines: + num_hashes = 0 + for char in line: + if char == '#': + num_hashes += 1 + else: + break + num_hashes -= 1 + num_spaces = 4 * num_hashes + new_line = ' ' * num_spaces + line + processed_lines.append(new_line) + return '\n'.join(processed_lines) + + @staticmethod + def get_current_time_string(): + """ + Returns the current time in the California time zone as a string. + + Returns: + str: The current California time in 'YYYY-MM-DD HH:MM:SS' format. 
+ """ + california_tz = pytz.timezone('America/Los_Angeles') + utc_now = datetime.datetime.now(datetime.timezone.utc) + california_now = utc_now.astimezone(california_tz) + return california_now.strftime('%Y-%m-%d %H:%M:%S') + + @staticmethod + def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'): + """ + Compares two time strings to determine if they represent the same point in time. + + Args: + time_string1 (str): The first time string to compare. + time_string2 (str): The second time string to compare. + time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'. + + Returns: + bool: True if the time strings represent the same time, False otherwise. + """ + # Parse the time strings into datetime objects + time1 = datetime.datetime.strptime(time_string1, time_format) + time2 = datetime.datetime.strptime(time_string2, time_format) + + # Compare the datetime objects + return time1 == time2 + + @staticmethod + def add_inline_citation_link(article_text, citation_dict): + # Regular expression to find citations like [i] + pattern = r'\[(\d+)\]' + + # Function to replace each citation with its Markdown link + def replace_with_link(match): + i = match.group(1) + url = citation_dict.get(int(i), {}).get('url', '#') + return f'[[{i}]]({url})' + + # Replace all citations in the text with Markdown links + return re.sub(pattern, replace_with_link, article_text) + + @staticmethod + def generate_html_toc(md_text): + toc = [] + for line in md_text.splitlines(): + if line.startswith("#"): + level = line.count("#") + title = line.strip("# ").strip() + anchor = title.lower().replace(" ", "-").replace(".", "") + toc.append(f"
  • {title}
  • ") + return "" + + @staticmethod + def construct_bibliography_from_url_to_info(url_to_info): + bibliography_list = [] + sorted_url_to_unified_index = dict(sorted(url_to_info['url_to_unified_index'].items(), + key=lambda item: item[1])) + for url, index in sorted_url_to_unified_index.items(): + title = url_to_info['url_to_info'][url]['title'] + bibliography_list.append(f"[{index}]: [{title}]({url})") + bibliography_string = "\n\n".join(bibliography_list) + return f"# References\n\n{bibliography_string}" + + +class DemoUIHelper(): + def st_markdown_adjust_size(content, font_size=20): + st.markdown(f""" + {content} + """, unsafe_allow_html=True) + + @staticmethod + def get_article_card_UI_style(boarder_color="#9AD8E1"): + return { + "card": { + "width": "100%", + "height": "116px", + "max-width": "640px", + "background-color": "#FFFFF", + "border": "1px solid #CCC", + "padding": "20px", + "border-radius": "5px", + "border-left": f"0.5rem solid {boarder_color}", + "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)", + "margin": "0px" + }, + "title": { + "white-space": "nowrap", + "overflow": "hidden", + "text-overflow": "ellipsis", + "font-size": "17px", + "color": "rgb(49, 51, 63)", + "text-align": "left", + "width": "95%", + "font-weight": "normal" + }, + "text": { + "white-space": "nowrap", + "overflow": "hidden", + "text-overflow": "ellipsis", + "font-size": "25px", + "color": "rgb(49, 51, 63)", + "text-align": "left", + "width": "95%" + }, + "filter": { + "background-color": "rgba(0, 0, 0, 0)" + } + } + + @staticmethod + def customize_toast_css_style(): + # Note padding is top right bottom left + st.markdown( + """ + + """, unsafe_allow_html=True + ) + + @staticmethod + def article_markdown_to_html(article_title, article_content): + return f""" + + + + {article_title} + + + +
    +

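+
+# Usage sketch for the helpers above (illustrative only; the inputs are
+# hypothetical, not produced by this module):
+#
+#   linked = DemoTextProcessingHelper.add_inline_citation_link(
+#       "STORM writes articles [1].", {1: {"url": "https://example.com"}})
+#   # -> "STORM writes articles [[1]](https://example.com)."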
+
+
+class DemoUIHelper():
+    @staticmethod
+    def st_markdown_adjust_size(content, font_size=20):
+        st.markdown(f"""
+        <span style='font-size: {font_size}px;'>{content}</span>
+        """, unsafe_allow_html=True)
+
+    @staticmethod
+    def get_article_card_UI_style(boarder_color="#9AD8E1"):
+        return {
+            "card": {
+                "width": "100%",
+                "height": "116px",
+                "max-width": "640px",
+                "background-color": "#FFFFFF",
+                "border": "1px solid #CCC",
+                "padding": "20px",
+                "border-radius": "5px",
+                "border-left": f"0.5rem solid {boarder_color}",
+                "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)",
+                "margin": "0px"
+            },
+            "title": {
+                "white-space": "nowrap",
+                "overflow": "hidden",
+                "text-overflow": "ellipsis",
+                "font-size": "17px",
+                "color": "rgb(49, 51, 63)",
+                "text-align": "left",
+                "width": "95%",
+                "font-weight": "normal"
+            },
+            "text": {
+                "white-space": "nowrap",
+                "overflow": "hidden",
+                "text-overflow": "ellipsis",
+                "font-size": "25px",
+                "color": "rgb(49, 51, 63)",
+                "text-align": "left",
+                "width": "95%"
+            },
+            "filter": {
+                "background-color": "rgba(0, 0, 0, 0)"
+            }
+        }
+
+    @staticmethod
+    def customize_toast_css_style():
+        # Note padding is top right bottom left
+        st.markdown(
+            """
+            <style>
+                div[data-testid=stToast] {
+                    padding: 20px 10px 40px 10px;
+                    width: 40%;
+                }
+
+                [data-testid=toastContainer] [data-testid=stMarkdownContainer] > p {
+                    font-size: 25px;
+                    font-style: normal;
+                    font-weight: 400;
+                    line-height: 1.5;
+                }
+            </style>
+            """, unsafe_allow_html=True
+        )
+
+    @staticmethod
+    def article_markdown_to_html(article_title, article_content):
+        return f"""
+        <html>
+            <head>
+                <meta charset="utf-8">
+                <title>{article_title}</title>
+                <style>
+                    .title {{
+                        text-align: center;
+                    }}
+                </style>
+            </head>
+            <body>
+                <div class="title">
+                    <h1>{article_title.replace('_', ' ')}</h1>
+                </div>
+                <h2>Table of Contents</h2>
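+
+# Sketch: exporting a finished article to standalone HTML and, via the pdfkit
+# import above, to PDF (file names are hypothetical; pdfkit needs a local
+# wkhtmltopdf install):
+#
+#   md = DemoFileIOHelper.read_txt_file("storm_gen_article_polished.txt")
+#   html = DemoUIHelper.article_markdown_to_html("My_Topic", md)
+#   pdfkit.from_string(html, "My_Topic.pdf")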
+                {DemoTextProcessingHelper.generate_html_toc(article_content)}
+                {markdown.markdown(article_content)}
+            </body>
+        </html>
+        """
+
+
+def _construct_citation_dict_from_search_result(search_results):
+    if search_results is None:
+        return None
+    citation_dict = {}
+    for url, index in search_results['url_to_unified_index'].items():
+        citation_dict[index] = {'url': url,
+                                'title': search_results['url_to_info'][url]['title'],
+                                'snippets': search_results['url_to_info'][url]['snippets']}
+    return citation_dict
+
+
+def _display_main_article_text(article_text, citation_dict, table_content_sidebar):
+    # Post-process the generated article for better display.
+    if "Write the lead section:" in article_text:
+        article_text = article_text[
+                       article_text.find("Write the lead section:") + len("Write the lead section:"):]
+    if article_text[0] == '#':
+        article_text = '\n'.join(article_text.split('\n')[1:])
+    article_text = DemoTextProcessingHelper.add_inline_citation_link(article_text, citation_dict)
+    # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown()
+    article_text = article_text.replace("$", "\\$")
+    stoc.from_markdown(article_text, table_content_sidebar)
+
+
+def _display_references(citation_dict):
+    if citation_dict:
+        reference_list = [f"reference [{i}]" for i in range(1, len(citation_dict) + 1)]
+        selected_key = st.selectbox("Select a reference", reference_list)
+        citation_val = citation_dict[reference_list.index(selected_key) + 1]
+        citation_val['title'] = citation_val['title'].replace("$", "\\$")
+        st.markdown(f"**Title:** {citation_val['title']}")
+        st.markdown(f"**Url:** {citation_val['url']}")
+        snippets = '\n\n'.join(citation_val['snippets']).replace("$", "\\$")
+        st.markdown(f"**Highlights:**\n\n {snippets}")
+    else:
+        st.markdown("**No references available**")
+
+
+def _display_persona_conversations(conversation_log):
+    """
+    Display persona conversations in a dialogue UI.
+    """
+    # get personas list as (persona_name, persona_description, dialogue turns list) tuples
+    parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(conversation_log)
+    # construct tabs for each persona conversation
+    persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history])
+    for idx, persona_tab in enumerate(persona_tabs):
+        with persona_tab:
+            # show persona description
+            st.info(parsed_conversation_history[idx][1])
+            # show user / agent utterances in the dialogue UI
+            for message in parsed_conversation_history[idx][2]:
+                message['content'] = message['content'].replace("$", "\\$")
+                with st.chat_message(message["role"]):
+                    if message["role"] == "user":
+                        st.markdown(f"**{message['content']}**")
+                    else:
+                        st.markdown(message["content"])
+
+
+def _display_main_article(selected_article_file_path_dict, show_reference=True, show_conversation=True):
+    article_data = DemoFileIOHelper.assemble_article_data(selected_article_file_path_dict)
+
+    with st.container(height=1000, border=True):
+        table_content_sidebar = st.sidebar.expander("**Table of contents**", expanded=True)
+        _display_main_article_text(article_text=article_data.get("article", ""),
+                                   citation_dict=article_data.get("citations", {}),
+                                   table_content_sidebar=table_content_sidebar)
+
+    # display reference panel
+    if show_reference and "citations" in article_data:
+        with st.sidebar.expander("**References**", expanded=True):
+            with st.container(height=800, border=False):
+                _display_references(citation_dict=article_data.get("citations", {}))
+
+    # display conversation history
+    if show_conversation and "conversation_log" in article_data:
+        with st.expander(
+                "**STORM** is powered by a knowledge agent that proactively researches a given topic by asking good questions coming from different perspectives.\n\n"
+                ":sunglasses: Click here to view the agent's brain**STORM**ing process!"):
+            _display_persona_conversations(conversation_log=article_data.get("conversation_log", {}))
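+
+# Sketch: assembling a previously generated topic for display outside the page
+# flow ("My_Topic" is hypothetical; it must be a folder under DEMO_WORKING_DIR):
+#
+#   paths = DemoFileIOHelper.read_structure_to_dict("DEMO_WORKING_DIR")
+#   data = DemoFileIOHelper.assemble_article_data(paths["My_Topic"])
+#   # data["article"], data.get("citations"), data.get("conversation_log")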
"conversation_log" in article_data: + with st.expander( + "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n" + ":sunglasses: Click here to view the agent's brain**STORM**ing process!"): + _display_persona_conversations(conversation_log=article_data.get("conversation_log", {})) + + +def get_demo_dir(): + return os.path.dirname(os.path.abspath(__file__)) + + +def clear_other_page_session_state(page_index: Optional[int]): + if page_index is None: + keys_to_delete = [key for key in st.session_state if key.startswith("page")] + else: + keys_to_delete = [key for key in st.session_state if key.startswith("page") and f"page{page_index}" not in key] + for key in set(keys_to_delete): + del st.session_state[key] + + +def set_storm_runner(): + current_working_dir = os.path.join(get_demo_dir(), "DEMO_WORKING_DIR") + if not os.path.exists(current_working_dir): + os.makedirs(current_working_dir) + + # configure STORM runner + llm_configs = STORMWikiLMConfigs() + llm_configs.init_openai_model(openai_api_key=st.secrets['OPENAI_API_KEY'], openai_type='openai') + llm_configs.set_question_asker_lm(OpenAIModel(model='gpt-4-1106-preview', api_key=st.secrets['OPENAI_API_KEY'], + api_provider='openai', + max_tokens=500, temperature=1.0, top_p=0.9)) + engine_args = STORMWikiRunnerArguments( + output_dir=current_working_dir, + max_conv_turn=3, + max_perspective=3, + search_top_k=3, + retrieve_top_k=5 + ) + + rm = YouRM(ydc_api_key=st.secrets['YDC_API_KEY'], k=engine_args.search_top_k) + + runner = STORMWikiRunner(engine_args, llm_configs, rm) + st.session_state["runner"] = runner + + +def display_article_page(selected_article_name, selected_article_file_path_dict, + show_title=True, show_main_article=True): + if show_title: + st.markdown(f"

    {selected_article_name.replace('_', ' ')}
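+
+# Sketch: the runner can be customized here following the main README's
+# customization guide. The values below are illustrative, not recommendations:
+#
+#   engine_args = STORMWikiRunnerArguments(
+#       output_dir=current_working_dir,
+#       max_conv_turn=5,     # longer simulated conversations
+#       max_perspective=5,   # more personas asking questions
+#       search_top_k=5,
+#       retrieve_top_k=10,
+#   )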
+
+
+def display_article_page(selected_article_name, selected_article_file_path_dict,
+                         show_title=True, show_main_article=True):
+    if show_title:
+        st.markdown(f"<h2 style='text-align: center;'>{selected_article_name.replace('_', ' ')}</h2>",
+                    unsafe_allow_html=True)
+
+    if show_main_article:
+        _display_main_article(selected_article_file_path_dict)
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+    def __init__(self, status_container):
+        self.status_container = status_container
+
+    def on_identify_perspective_start(self, **kwargs):
+        self.status_container.info('Start identifying different perspectives for researching the topic.')
+
+    def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
+        perspective_list = "\n- ".join(perspectives)
+        self.status_container.success(f'Finish identifying perspectives. Will now start gathering information'
+                                      f' from the following perspectives:\n- {perspective_list}')
+
+    def on_information_gathering_start(self, **kwargs):
+        self.status_container.info('Start browsing the Internet.')
+
+    def on_dialogue_turn_end(self, dlg_turn, **kwargs):
+        urls = list(set([r.url for r in dlg_turn.search_results]))
+        for url in urls:
+            self.status_container.markdown(f"""
+                    <style>
+                        .small-font {{
+                            font-size: 14px;
+                            margin: 0px;
+                            padding: 0px;
+                        }}
+                    </style>
+                    <div class="small-font">Finish browsing <a href="{url}" class="small-font" target="_blank">{url}</a>.</div>
+                    """, unsafe_allow_html=True)
+
+    def on_information_gathering_end(self, **kwargs):
+        self.status_container.success('Finish collecting information.')
+
+    def on_information_organization_start(self, **kwargs):
+        self.status_container.info('Start organizing information into a hierarchical outline.')
+
+    def on_direct_outline_generation_end(self, outline: str, **kwargs):
+        self.status_container.success('Finish leveraging the internal knowledge of the large language model.')
+
+    def on_outline_refinement_end(self, outline: str, **kwargs):
+        self.status_container.success('Finish leveraging the collected information.')
diff --git a/frontend/demo_light/pages_util/CreateNewArticle.py b/frontend/demo_light/pages_util/CreateNewArticle.py
new file mode 100644
index 00000000..9495ffe5
--- /dev/null
+++ b/frontend/demo_light/pages_util/CreateNewArticle.py
@@ -0,0 +1,102 @@
+import os
+import time
+
+import demo_util
+import streamlit as st
+from demo_util import DemoFileIOHelper, DemoTextProcessingHelper, DemoUIHelper
+
+
+def create_new_article_page():
+    demo_util.clear_other_page_session_state(page_index=3)
+
+    if "page3_write_article_state" not in st.session_state:
+        st.session_state["page3_write_article_state"] = "not started"
+
+    if st.session_state["page3_write_article_state"] == "not started":
+
+        _, search_form_column, _ = st.columns([2, 5, 2])
+        with search_form_column:
+            with st.form(key='search_form'):
+                # Text input for the search topic
+                DemoUIHelper.st_markdown_adjust_size(content="Enter the topic you want to learn in depth:",
+                                                     font_size=18)
+                st.session_state["page3_topic"] = st.text_input(label='page3_topic', label_visibility="collapsed")
+                pass_appropriateness_check = True
+
+                # Submit button for the form
+                submit_button = st.form_submit_button(label='Research')
+                # only start a new search when the button is clicked and the
+                # previous run has not started or has already finished
+                if submit_button and st.session_state["page3_write_article_state"] in ["not started", "show results"]:
+                    if not st.session_state["page3_topic"].strip():
+                        pass_appropriateness_check = False
+                        st.session_state["page3_warning_message"] = "Topic cannot be empty"
+
+                    st.session_state["page3_topic_name_cleaned"] = st.session_state["page3_topic"].replace(
+                        ' ', '_').replace('/', '_')
+                    if not pass_appropriateness_check:
+                        st.session_state["page3_write_article_state"] = "not started"
+                        alert = st.warning(st.session_state["page3_warning_message"], icon="⚠️")
+                        time.sleep(5)
+                        alert.empty()
+                    else:
+                        st.session_state["page3_write_article_state"] = "initiated"
+
+    if st.session_state["page3_write_article_state"] == "initiated":
+        current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
+        if not os.path.exists(current_working_dir):
+            os.makedirs(current_working_dir)
+
+        if "runner" not in st.session_state:
+            demo_util.set_storm_runner()
+        st.session_state["page3_current_working_dir"] = current_working_dir
+        st.session_state["page3_write_article_state"] = "pre_writing"
+
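+    # The "pre_writing" run below is phase one of two: research and outline
+    # generation only. Article generation and polishing happen later in the
+    # "final_writing" state.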
+    if st.session_state["page3_write_article_state"] == "pre_writing":
+        status = st.status("I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)")
+        st_callback_handler = demo_util.StreamlitCallbackHandler(status)
+        with status:
+            # STORM main gen outline
+            st.session_state["runner"].run(
+                topic=st.session_state["page3_topic"],
+                do_research=True,
+                do_generate_outline=True,
+                do_generate_article=False,
+                do_polish_article=False,
+                callback_handler=st_callback_handler
+            )
+            conversation_log_path = os.path.join(st.session_state["page3_current_working_dir"],
+                                                 st.session_state["page3_topic_name_cleaned"], "conversation_log.json")
+            demo_util._display_persona_conversations(DemoFileIOHelper.read_json_file(conversation_log_path))
+            st.session_state["page3_write_article_state"] = "final_writing"
+            status.update(label="brain**STORM**ing complete!", state="complete")
+
+    if st.session_state["page3_write_article_state"] == "final_writing":
+        # polish the final article
+        with st.status(
+                "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)") as status:
+            st.info('Now I will connect the information I found for your reference. (This may take 4-5 minutes.)')
+            st.session_state["runner"].run(topic=st.session_state["page3_topic"], do_research=False,
+                                           do_generate_outline=False,
+                                           do_generate_article=True, do_polish_article=True, remove_duplicate=False)
+            # finish the session
+            st.session_state["runner"].post_run()
+
+            # update the status bar
+            st.session_state["page3_write_article_state"] = "prepare_to_show_result"
+            status.update(label="information synthesis complete!", state="complete")
+
+    if st.session_state["page3_write_article_state"] == "prepare_to_show_result":
+        _, show_result_col, _ = st.columns([4, 3, 4])
+        with show_result_col:
+            if st.button("show final article"):
+                st.session_state["page3_write_article_state"] = "completed"
+                st.rerun()
+
+    if st.session_state["page3_write_article_state"] == "completed":
+        # display the polished article
+        current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict(
+            st.session_state["page3_current_working_dir"])
+        current_article_file_path_dict = current_working_dir_paths[st.session_state["page3_topic_name_cleaned"]]
+        demo_util.display_article_page(selected_article_name=st.session_state["page3_topic_name_cleaned"],
+                                       selected_article_file_path_dict=current_article_file_path_dict,
+                                       show_title=True, show_main_article=True)
diff --git a/frontend/demo_light/pages_util/MyArticles.py b/frontend/demo_light/pages_util/MyArticles.py
new file mode 100644
index 00000000..f2c306f2
--- /dev/null
+++ b/frontend/demo_light/pages_util/MyArticles.py
@@ -0,0 +1,89 @@
+import os
+
+import demo_util
+import streamlit as st
+from demo_util import DemoFileIOHelper, DemoUIHelper
+from streamlit_card import card
+
+
+# set page config and display title
+def my_articles_page():
+    with st.sidebar:
+        _, return_button_col = st.columns([2, 5])
+        with return_button_col:
+            if st.button("Select another article", disabled="page2_selected_my_article" not in st.session_state):
+                if "page2_selected_my_article" in st.session_state:
+                    del st.session_state["page2_selected_my_article"]
+                st.rerun()
+
+    # sync my articles
+    if "page2_user_articles_file_path_dict" not in st.session_state:
+        local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
+        os.makedirs(local_dir, exist_ok=True)
+        st.session_state["page2_user_articles_file_path_dict"] = DemoFileIOHelper.read_structure_to_dict(local_dir)
+
+    # if no article is selected, display all stored articles as info cards
+    def article_card_setup(column_to_add, card_title, article_name):
+        with column_to_add:
+            cleaned_article_title = article_name.replace("_", " ")
+            hasClicked = card(title=" / ".join(card_title),
+                              text=cleaned_article_title,
+                              image=DemoFileIOHelper.read_image_as_base64(
+                                  os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
+                              styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1"))
+            if hasClicked:
+                st.session_state["page2_selected_my_article"] = article_name
+                st.rerun()
+
+    if "page2_selected_my_article" not in st.session_state:
+        # display article cards
+        my_article_columns = st.columns(3)
+        if len(st.session_state["page2_user_articles_file_path_dict"]) > 0:
+            # get article names
+            article_names = sorted(list(st.session_state["page2_user_articles_file_path_dict"].keys()))
+            # configure pagination
+            pagination = st.container()
+            bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1]
+            with bottom_menu[2]:
+                batch_size = st.selectbox("Page Size", options=[24, 48, 72])
+            with bottom_menu[1]:
+                # ceiling division so a partially filled last page still counts
+                total_pages = max(1, -(-len(article_names) // batch_size))
+                current_page = st.number_input(
+                    "Page", min_value=1, max_value=total_pages, step=1
+                )
+            with bottom_menu[0]:
+                st.markdown(f"Page **{current_page}** of **{total_pages}** ")
+            # show article cards
+            with pagination:
+                my_article_count = 0
+                start_index = (current_page - 1) * batch_size
+                end_index = min(current_page * batch_size, len(article_names))
+                for article_name in article_names[start_index: end_index]:
+                    column_to_add = my_article_columns[my_article_count % 3]
+                    my_article_count += 1
+                    article_card_setup(column_to_add=column_to_add,
+                                       card_title=["My Article"],
+                                       article_name=article_name)
+        else:
+            with my_article_columns[0]:
+                hasClicked = card(title="Get started",
+                                  text="Start your first research!",
+                                  image=DemoFileIOHelper.read_image_as_base64(
+                                      os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
+                                  styles=DemoUIHelper.get_article_card_UI_style())
+                if hasClicked:
+                    st.session_state.selected_page = 1
+                    st.session_state["manual_selection_override"] = True
+                    st.session_state["rerun_requested"] = True
+                    st.rerun()
+    else:
+        selected_article_name = st.session_state["page2_selected_my_article"]
+        selected_article_file_path_dict = st.session_state["page2_user_articles_file_path_dict"][selected_article_name]
+
+        demo_util.display_article_page(selected_article_name=selected_article_name,
+                                       selected_article_file_path_dict=selected_article_file_path_dict,
+                                       show_title=True, show_main_article=True)
diff --git a/frontend/demo_light/requirements.txt b/frontend/demo_light/requirements.txt
new file mode 100644
index 00000000..a5ba2884
--- /dev/null
+++ b/frontend/demo_light/requirements.txt
@@ -0,0 +1,10 @@
+streamlit==1.31.1
+streamlit-card
+markdown
+unidecode
+extra-streamlit-components==0.1.60
+streamlit_extras
+deprecation==2.1.0
+st-pages==0.4.5
+streamlit-float
+streamlit-option-menu
\ No newline at end of file
diff --git a/frontend/demo_light/stoc.py b/frontend/demo_light/stoc.py
new file mode 100644
index 00000000..7bd4402b
--- /dev/null
+++ b/frontend/demo_light/stoc.py
@@ -0,0 +1,131 @@
+"""https://github.com/arnaudmiribel/stoc"""
+
+import re
+
+import streamlit as st
+import unidecode
+
+DISABLE_LINK_CSS = """
+<style>
+a.toc {
+    color: inherit;
+    text-decoration: none;  /* no underline */
+}
+</style>"""
+
+
+class stoc:
+    def __init__(self):
+        self.toc_items = list()
+
+    def h1(self, text: str, write: bool = True):
+        if write:
+            st.write(f"# {text}")
+        self.toc_items.append(("h1", text))
+
+    def h2(self, text: str, write: bool = True):
+        if write:
+            st.write(f"## {text}")
+        self.toc_items.append(("h2", text))
+
+    def h3(self, text: str, write: bool = True):
+        if write:
+            st.write(f"### {text}")
+        self.toc_items.append(("h3", text))
+
+    def toc(self, expander):
+        st.write(DISABLE_LINK_CSS, unsafe_allow_html=True)
+        # st.sidebar.caption("Table of contents")
+        if expander is None:
+            expander = st.sidebar.expander("**Table of contents**", expanded=True)
+        with expander:
+            with st.container(height=600, border=False):
+                markdown_toc = ""
+                for title_size, title in self.toc_items:
+                    h = int(title_size.replace("h", ""))
+                    markdown_toc += (
+                        " " * 2 * h
+                        + "- "
+                        + f'<a href="#{normalize(title)}" class="toc">{title}</a> \n'
+                    )
+                # st.sidebar.write(markdown_toc, unsafe_allow_html=True)
+                st.write(markdown_toc, unsafe_allow_html=True)
+
+    @classmethod
+    def get_toc(cls, markdown_text: str, topic=""):
+        def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading):
+            lines = markdown_text.splitlines()
+            # Increase the depth of each heading by adding an extra '#'
+            increased_depth_lines = ['#' + line if line.startswith('#') else line for line in lines]
+            # Add the new top-level heading at the beginning
+            increased_depth_lines.insert(0, f"# {new_top_heading}")
+            # Re-join the modified lines back into a single string
+            modified_text = '\n'.join(increased_depth_lines)
+            return modified_text
+
+        if topic:
+            markdown_text = increase_heading_depth_and_add_top_heading(markdown_text, topic)
+        toc = []
+        for line in markdown_text.splitlines():
+            if line.startswith('#'):
+                # Remove the '#' characters and strip leading/trailing spaces
+                heading_text = line.lstrip('#').strip()
+                # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters)
+                slug = re.sub(r'[^a-zA-Z0-9\s-]', '', heading_text).lower().replace(' ', '-')
+                # Determine heading level for indentation
+                level = line.count('#') - 1
+                # Add to the table of contents
+                toc.append('  ' * level + f'- [{heading_text}](#{slug})')
+        return '\n'.join(toc)
+
+    @classmethod
+    def from_markdown(cls, text: str, expander=None):
+        self = cls()
+        for line in text.splitlines():
+            if line.startswith("###"):
+                self.h3(line[3:], write=False)
+            elif line.startswith("##"):
+                self.h2(line[2:], write=False)
+            elif line.startswith("#"):
+                self.h1(line[1:], write=False)
+        # customize markdown font size
+        custom_css = """
+        <style>
+            /* Adjust the font size for headings */
+            h1 { font-size: 28px; }
+            h2 { font-size: 24px; }
+            h3 { font-size: 22px; }
+            h4 { font-size: 20px; }
+            h5 { font-size: 18px; }
+            /* Adjust the font size for normal text */
+            p { font-size: 18px; }
+        </style>
+        """
+        st.markdown(custom_css, unsafe_allow_html=True)
+
+        st.write(text)
+        self.toc(expander=expander)
+
+
+def normalize(s):
+    """
+    Normalize titles as valid HTML ids for anchors
+    >>> normalize("it's a test to spot how Things happ3n héhé")
+    "it-s-a-test-to-spot-how-things-happ3n-h-h"
+    """
+
+    # Replace accents with "-"
+    s_wo_accents = unidecode.unidecode(s)
+    accents = [s for s in s if s not in s_wo_accents]
+    for accent in accents:
+        s = s.replace(accent, "-")
+
+    # Lowercase
+    s = s.lower()
+
+    # Keep only alphanumeric characters, replace the rest with "-", and strip
+    # any leading/trailing "-"
+    normalized = (
+        "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower()
+    )
+
+    return normalized
diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py
new file mode 100644
index 00000000..9a0ae663
--- /dev/null
+++ b/frontend/demo_light/storm.py
@@ -0,0 +1,64 @@
+import os
+import sys
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+wiki_root_dir = os.path.dirname(os.path.dirname(script_dir))
+
+sys.path.append(os.path.normpath(os.path.join(script_dir, '../../src/storm_wiki')))
+sys.path.append(os.path.normpath(os.path.join(script_dir, '../../src')))
+
+import demo_util
+import streamlit as st
+from pages_util import MyArticles, CreateNewArticle
+from streamlit_float import *
+from streamlit_option_menu import option_menu
+
+
+def main():
+    st.set_page_config(layout='wide')
+
+    if "first_run" not in st.session_state:
+        st.session_state['first_run'] = True
+
+    # set api keys from secrets
+    if st.session_state['first_run']:
+        for key, value in st.secrets.items():
+            if isinstance(value, str):
+                os.environ[key] = value
+
+    # initialize session_state
+    if "selected_article_index" not in st.session_state:
+        st.session_state["selected_article_index"] = 0
+    if "selected_page" not in st.session_state:
+        st.session_state["selected_page"] = 0
+    if st.session_state.get("rerun_requested", False):
+        st.session_state["rerun_requested"] = False
+        st.rerun()
+
+    st.write('<style>div.block-container{padding-top: 1rem;}</style>', unsafe_allow_html=True)
+    menu_container = st.container()
+    with menu_container:
+        pages = ["My Articles", "Create New Article"]
+        menu_selection = option_menu(None, pages,
+                                     icons=['house', 'search'],
+                                     menu_icon="cast", default_index=0, orientation="horizontal",
+                                     manual_select=st.session_state.selected_page,
+                                     styles={
+                                         "container": {"padding": "0.2rem 0", "background-color": "#22222200"},
+                                     },
+                                     key='menu_selection')
+    if st.session_state.get("manual_selection_override", False):
+        menu_selection = pages[st.session_state["selected_page"]]
+        st.session_state["manual_selection_override"] = False
+        st.session_state["selected_page"] = None
+
+    if menu_selection == "My Articles":
+        demo_util.clear_other_page_session_state(page_index=2)
+        MyArticles.my_articles_page()
+    elif menu_selection == "Create New Article":
+        demo_util.clear_other_page_session_state(page_index=3)
+        CreateNewArticle.create_new_article_page()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/storm_wiki/engine.py b/src/storm_wiki/engine.py
index 66fa4b05..4191d35a 100644
--- a/src/storm_wiki/engine.py
+++ b/src/storm_wiki/engine.py
@@ -48,24 +48,9 @@ def init_openai_model(
             'api_provider': openai_type,
             'temperature': temperature,
             'top_p': top_p,
-            'api_base': None,
-            'api_version': None,
+            'api_base': None
         }
-        if openai_type and openai_type == 'azure':
-            openai_kwargs['api_base'] = api_base
-            openai_kwargs['api_version'] = api_version
-
-            self.conv_simulator_lm = OpenAIModel(model='gpt-35-turbo-instruct', engine='gpt-35-turbo-instruct',
-                                                 max_tokens=500, **openai_kwargs)
-            self.question_asker_lm = OpenAIModel(model='gpt-35-turbo', engine='gpt-35-turbo',
-                                                 max_tokens=500, **openai_kwargs)
-            self.outline_gen_lm = OpenAIModel(model='gpt-4', engine='gpt-4',
-                                              max_tokens=400, **openai_kwargs)
-            self.article_gen_lm = OpenAIModel(model='gpt-4', engine='gpt-4',
-                                              max_tokens=700, **openai_kwargs)
-            self.article_polish_lm = OpenAIModel(model='gpt-4-32k', engine='gpt-4-32k',
-                                                 max_tokens=4000, **openai_kwargs)
-        elif openai_type and openai_type == 'openai':
+        if openai_type and openai_type == 'openai':
             self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct',
                                                  max_tokens=500, **openai_kwargs)
             self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo',
@@ -73,9 +58,9 @@ def init_openai_model(
                                                  max_tokens=500, **openai_kwargs)
         # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.)
self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs) - self.article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', + self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=700, **openai_kwargs) - self.article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', + self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=4000, **openai_kwargs) else: logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.')
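
With the Azure-specific branch removed, `openai_type='openai'` is now the only path that populates the default LM configurations, and article generation and polishing default to `gpt-4o-2024-05-13`. A minimal sketch of how the updated defaults get picked up (the key value is a placeholder):

```python
from storm_wiki.engine import STORMWikiLMConfigs

# Conversation simulation stays on gpt-3.5-turbo-instruct; outline generation
# stays on gpt-4-0125-preview; article generation and polishing now default to
# gpt-4o-2024-05-13.
llm_configs = STORMWikiLMConfigs()
llm_configs.init_openai_model(openai_api_key="YOUR_OPENAI_API_KEY", openai_type='openai')
```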