diff --git a/README.md b/README.md
index 5ba1c427..569e7303 100644
--- a/README.md
+++ b/README.md
@@ -255,4 +255,4 @@ Please cite our paper if you use this code or part of it in your work:
year={2024},
booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)}
}
-```
+```
\ No newline at end of file
diff --git a/frontend/demo_light/.streamlit/config.toml b/frontend/demo_light/.streamlit/config.toml
new file mode 100644
index 00000000..e1c8593f
--- /dev/null
+++ b/frontend/demo_light/.streamlit/config.toml
@@ -0,0 +1,10 @@
+[client]
+showErrorDetails = false
+toolbarMode = "minimal"
+
+[theme]
+primaryColor = "#F63366"
+backgroundColor = "#FFFFFF"
+secondaryBackgroundColor = "#F0F2F6"
+textColor = "#262730"
+font = "sans serif"
\ No newline at end of file
diff --git a/frontend/demo_light/README.md b/frontend/demo_light/README.md
new file mode 100644
index 00000000..6a41c0e0
--- /dev/null
+++ b/frontend/demo_light/README.md
@@ -0,0 +1,33 @@
+# STORM Minimal User Interface
+
+This is a minimal user interface for `STORMWikiRunner` which includes the following features:
+1. Allowing users to create a new article through the "Create New Article" page.
+2. Showing the intermediate steps of `STORMWikiRunner` in real time while creating an article.
+3. Displaying the written article and references side by side.
+4. Allowing users to view previously created articles through the "My Articles" page.
+
+<p align="center">
+  <img src="assets/create_article.jpg" style="width: 70%;">
+  <img src="assets/article_display.jpg" style="width: 70%;">
+</p>
+
+## Setup
+1. Besides the required packages for `STORMWikiRunner`, you need to install additional packages:
+ ```bash
+ pip install -r requirements.txt
+ ```
+2. Make sure you set up the API keys following the instructions in the main README file. Then create a copy of `secrets.toml` and place it under `.streamlit/` (see the sketch after this list).
+3. Run the following command to start the user interface:
+ ```bash
+ streamlit run storm.py
+ ```
+   The user interface will create a `DEMO_WORKING_DIR` directory alongside `storm.py` to store the outputs.
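+
+A minimal `.streamlit/secrets.toml` might look like the following sketch; the key names match what `set_storm_runner()` in [demo_util.py](demo_util.py) reads via `st.secrets`, and the values are placeholders:
+
+```toml
+# Placeholder values -- replace with your own keys and do not commit this file.
+OPENAI_API_KEY = "sk-..."
+YDC_API_KEY = "..."
+```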
+
+## Customization
+
+You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file.
+
+The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need.
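+
+As a concrete sketch (illustrative values only), you could adjust the engine arguments that `set_storm_runner()` already constructs, reusing the local variables it defines:
+
+```python
+# Sketch: a variant of the arguments built inside set_storm_runner().
+engine_args = STORMWikiRunnerArguments(
+    output_dir=current_working_dir,
+    max_conv_turn=5,     # more dialogue turns per perspective
+    max_perspective=5,   # research the topic from more perspectives
+    search_top_k=3,
+    retrieve_top_k=5,
+)
+runner = STORMWikiRunner(engine_args, llm_configs, rm)
+```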
diff --git a/frontend/demo_light/assets/article_display.jpg b/frontend/demo_light/assets/article_display.jpg
new file mode 100644
index 00000000..8b0236c3
Binary files /dev/null and b/frontend/demo_light/assets/article_display.jpg differ
diff --git a/frontend/demo_light/assets/create_article.jpg b/frontend/demo_light/assets/create_article.jpg
new file mode 100644
index 00000000..35b44f90
Binary files /dev/null and b/frontend/demo_light/assets/create_article.jpg differ
diff --git a/frontend/demo_light/assets/void.jpg b/frontend/demo_light/assets/void.jpg
new file mode 100644
index 00000000..1cda9a53
Binary files /dev/null and b/frontend/demo_light/assets/void.jpg differ
diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py
new file mode 100644
index 00000000..d940aa09
--- /dev/null
+++ b/frontend/demo_light/demo_util.py
@@ -0,0 +1,572 @@
+import base64
+import datetime
+import io
+import json
+import os
+import re
+from typing import Optional
+
+import markdown
+import pdfkit
+import pytz
+import streamlit as st
+from lm import OpenAIModel
+from rm import YouRM
+from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
+from storm_wiki.modules.callback import BaseCallbackHandler
+
+from stoc import stoc
+
+
+class DemoFileIOHelper():
+ @staticmethod
+ def read_structure_to_dict(articles_root_path):
+ """
+ Reads the directory structure of articles stored in the given root path and
+ returns a nested dictionary. The outer dictionary has article names as keys,
+ and each value is another dictionary mapping file names to their absolute paths.
+
+ Args:
+ articles_root_path (str): The root directory path containing article subdirectories.
+
+ Returns:
+ dict: A dictionary where each key is an article name, and each value is a dictionary
+ of file names and their absolute paths within that article's directory.
+ """
+ articles_dict = {}
+ for topic_name in os.listdir(articles_root_path):
+ topic_path = os.path.join(articles_root_path, topic_name)
+ if os.path.isdir(topic_path):
+ # Initialize or update the dictionary for the topic
+ articles_dict[topic_name] = {}
+ # Iterate over all files within a topic directory
+ for file_name in os.listdir(topic_path):
+ file_path = os.path.join(topic_path, file_name)
+ articles_dict[topic_name][file_name] = os.path.abspath(file_path)
+ return articles_dict
+
+ @staticmethod
+ def read_txt_file(file_path):
+ """
+ Reads the contents of a text file and returns it as a string.
+
+ Args:
+ file_path (str): The path to the text file to be read.
+
+ Returns:
+ str: The content of the file as a single string.
+ """
+ with open(file_path) as f:
+ return f.read()
+
+ @staticmethod
+ def read_json_file(file_path):
+ """
+ Reads a JSON file and returns its content as a Python dictionary or list,
+ depending on the JSON structure.
+
+ Args:
+ file_path (str): The path to the JSON file to be read.
+
+ Returns:
+ dict or list: The content of the JSON file. The type depends on the
+ structure of the JSON file (object or array at the root).
+ """
+ with open(file_path) as f:
+ return json.load(f)
+
+ @staticmethod
+ def read_image_as_base64(image_path):
+ """
+ Reads an image file and returns its content encoded as a base64 string,
+ suitable for embedding in HTML or transferring over networks where binary
+ data cannot be easily sent.
+
+ Args:
+ image_path (str): The path to the image file to be encoded.
+
+ Returns:
+ str: The base64 encoded string of the image, prefixed with the necessary
+ data URI scheme for images.
+ """
+ with open(image_path, "rb") as f:
+ data = f.read()
+ encoded = base64.b64encode(data)
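+        # Note: the MIME type below is hard-coded as image/png even when the
+        # underlying asset is a JPEG (e.g., the files under assets/); browsers
+        # generally sniff the real format and render it anyway.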
+ data = "data:image/png;base64," + encoded.decode("utf-8")
+ return data
+
+ @staticmethod
+ def set_file_modification_time(file_path, modification_time_string):
+ """
+ Sets the modification time of a file based on a given time string in the California time zone.
+
+ Args:
+ file_path (str): The path to the file.
+ modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format.
+ """
+ california_tz = pytz.timezone('America/Los_Angeles')
+ modification_time = datetime.datetime.strptime(modification_time_string, '%Y-%m-%d %H:%M:%S')
+ modification_time = california_tz.localize(modification_time)
+ modification_time_utc = modification_time.astimezone(datetime.timezone.utc)
+ modification_timestamp = modification_time_utc.timestamp()
+ os.utime(file_path, (modification_timestamp, modification_timestamp))
+
+ @staticmethod
+ def get_latest_modification_time(path):
+ """
+        Returns the latest modification time among all files under the given path
+        (a file or a directory) in the California time zone as a string.
+
+        Args:
+            path (str): The path to a file or directory.
+
+        Returns:
+            str: The latest modification time in 'YYYY-MM-DD HH:MM:SS' format.
+            If no files are found, the current time is returned.
+ """
+ california_tz = pytz.timezone('America/Los_Angeles')
+ latest_mod_time = None
+
+ file_paths = []
+ if os.path.isdir(path):
+ for root, dirs, files in os.walk(path):
+ for file in files:
+ file_paths.append(os.path.join(root, file))
+ else:
+ file_paths = [path]
+
+ for file_path in file_paths:
+ modification_timestamp = os.path.getmtime(file_path)
+ modification_time_utc = datetime.datetime.utcfromtimestamp(modification_timestamp)
+ modification_time_utc = modification_time_utc.replace(tzinfo=datetime.timezone.utc)
+ modification_time_california = modification_time_utc.astimezone(california_tz)
+
+ if latest_mod_time is None or modification_time_california > latest_mod_time:
+ latest_mod_time = modification_time_california
+
+ if latest_mod_time is not None:
+ return latest_mod_time.strftime('%Y-%m-%d %H:%M:%S')
+ else:
+ return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+ @staticmethod
+ def assemble_article_data(article_file_path_dict):
+ """
+ Constructs a dictionary containing the content and metadata of an article
+ based on the available files in the article's directory. This includes the
+ main article text, citations from a JSON file, and a conversation log if
+ available. The function prioritizes a polished version of the article if
+ both a raw and polished version exist.
+
+ Args:
+            article_file_path_dict (dict): A dictionary where keys are file names relevant
+                                           to the article (e.g., the article text, citations
+                                           in JSON format, conversation logs) and values
+                                           are their corresponding file paths.
+
+ Returns:
+ dict or None: A dictionary containing the parsed content of the article,
+ citations, and conversation log if available. Returns None
+ if neither the raw nor polished article text exists in the
+ provided file paths.
+ """
+ if "storm_gen_article.txt" in article_file_path_dict or "storm_gen_article_polished.txt" in article_file_path_dict:
+ full_article_name = "storm_gen_article_polished.txt" if "storm_gen_article_polished.txt" in article_file_path_dict else "storm_gen_article.txt"
+ article_data = {"article": DemoTextProcessingHelper.parse(
+ DemoFileIOHelper.read_txt_file(article_file_path_dict[full_article_name]))}
+ if "url_to_info.json" in article_file_path_dict:
+ article_data["citations"] = _construct_citation_dict_from_search_result(
+ DemoFileIOHelper.read_json_file(article_file_path_dict["url_to_info.json"]))
+ if "conversation_log.json" in article_file_path_dict:
+ article_data["conversation_log"] = DemoFileIOHelper.read_json_file(
+ article_file_path_dict["conversation_log.json"])
+ return article_data
+ return None
+
+
+class DemoTextProcessingHelper():
+
+ @staticmethod
+ def remove_citations(sent):
+ return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
+
+ @staticmethod
+ def parse_conversation_history(json_data):
+ """
+ Given conversation log data, return list of parsed data of following format
+ (persona_name, persona_description, list of dialogue turn)
+ """
+ parsed_data = []
+ for persona_conversation_data in json_data:
+ if ': ' in persona_conversation_data["perspective"]:
+ name, description = persona_conversation_data["perspective"].split(": ", 1)
+ elif '- ' in persona_conversation_data["perspective"]:
+ name, description = persona_conversation_data["perspective"].split("- ", 1)
+ else:
+ name, description = "", persona_conversation_data["perspective"]
+ cur_conversation = []
+ for dialogue_turn in persona_conversation_data["dlg_turns"]:
+ cur_conversation.append({"role": "user", "content": dialogue_turn["user_utterance"]})
+ cur_conversation.append(
+ {"role": "assistant",
+ "content": DemoTextProcessingHelper.remove_citations(dialogue_turn["agent_utterance"])})
+ parsed_data.append((name, description, cur_conversation))
+ return parsed_data
+
+ @staticmethod
+ def parse(text):
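+        # Strip the quoted snippet from reference lines of the form
+        # `[n]: "snippet" http...`, leaving just `[n]: http...`.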
+ regex = re.compile(r']:\s+"(.*?)"\s+http')
+ text = regex.sub(']: http', text)
+ return text
+
+ @staticmethod
+ def add_markdown_indentation(input_string):
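+        # Indent each heading by 4 spaces per level beyond h1 so nested sections
+        # are visually offset; non-heading lines are left unindented.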
+ lines = input_string.split('\n')
+ processed_lines = [""]
+ for line in lines:
+ num_hashes = 0
+ for char in line:
+ if char == '#':
+ num_hashes += 1
+ else:
+ break
+ num_hashes -= 1
+ num_spaces = 4 * num_hashes
+ new_line = ' ' * num_spaces + line
+ processed_lines.append(new_line)
+ return '\n'.join(processed_lines)
+
+ @staticmethod
+ def get_current_time_string():
+ """
+ Returns the current time in the California time zone as a string.
+
+ Returns:
+ str: The current California time in 'YYYY-MM-DD HH:MM:SS' format.
+ """
+ california_tz = pytz.timezone('America/Los_Angeles')
+ utc_now = datetime.datetime.now(datetime.timezone.utc)
+ california_now = utc_now.astimezone(california_tz)
+ return california_now.strftime('%Y-%m-%d %H:%M:%S')
+
+ @staticmethod
+ def compare_time_strings(time_string1, time_string2, time_format='%Y-%m-%d %H:%M:%S'):
+ """
+ Compares two time strings to determine if they represent the same point in time.
+
+ Args:
+ time_string1 (str): The first time string to compare.
+ time_string2 (str): The second time string to compare.
+ time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'.
+
+ Returns:
+ bool: True if the time strings represent the same time, False otherwise.
+ """
+ # Parse the time strings into datetime objects
+ time1 = datetime.datetime.strptime(time_string1, time_format)
+ time2 = datetime.datetime.strptime(time_string2, time_format)
+
+ # Compare the datetime objects
+ return time1 == time2
+
+ @staticmethod
+ def add_inline_citation_link(article_text, citation_dict):
+ # Regular expression to find citations like [i]
+ pattern = r'\[(\d+)\]'
+
+ # Function to replace each citation with its Markdown link
+ def replace_with_link(match):
+ i = match.group(1)
+ url = citation_dict.get(int(i), {}).get('url', '#')
+ return f'[[{i}]]({url})'
+
+ # Replace all citations in the text with Markdown links
+ return re.sub(pattern, replace_with_link, article_text)
+
+ @staticmethod
+ def generate_html_toc(md_text):
+ toc = []
+ for line in md_text.splitlines():
+ if line.startswith("#"):
+ level = line.count("#")
+ title = line.strip("# ").strip()
+ anchor = title.lower().replace(" ", "-").replace(".", "")
+                toc.append(f"<li style='margin-left: {(level - 1) * 20}px;'>"
+                           f"<a href='#{anchor}'>{title}</a></li>")
+        return "<ul>" + "".join(toc) + "</ul>"
+
+ @staticmethod
+ def construct_bibliography_from_url_to_info(url_to_info):
+ bibliography_list = []
+ sorted_url_to_unified_index = dict(sorted(url_to_info['url_to_unified_index'].items(),
+ key=lambda item: item[1]))
+ for url, index in sorted_url_to_unified_index.items():
+ title = url_to_info['url_to_info'][url]['title']
+ bibliography_list.append(f"[{index}]: [{title}]({url})")
+ bibliography_string = "\n\n".join(bibliography_list)
+ return f"# References\n\n{bibliography_string}"
+
+
+class DemoUIHelper():
+    @staticmethod
+    def st_markdown_adjust_size(content, font_size=20):
+        st.markdown(f"""
+        <span style='font-size: {font_size}px;'>{content}</span>
+        """, unsafe_allow_html=True)
+
+ @staticmethod
+    def get_article_card_UI_style(border_color="#9AD8E1"):
+ return {
+ "card": {
+ "width": "100%",
+ "height": "116px",
+ "max-width": "640px",
+ "background-color": "#FFFFF",
+ "border": "1px solid #CCC",
+ "padding": "20px",
+ "border-radius": "5px",
+ "border-left": f"0.5rem solid {boarder_color}",
+ "box-shadow": "0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)",
+ "margin": "0px"
+ },
+ "title": {
+ "white-space": "nowrap",
+ "overflow": "hidden",
+ "text-overflow": "ellipsis",
+ "font-size": "17px",
+ "color": "rgb(49, 51, 63)",
+ "text-align": "left",
+ "width": "95%",
+ "font-weight": "normal"
+ },
+ "text": {
+ "white-space": "nowrap",
+ "overflow": "hidden",
+ "text-overflow": "ellipsis",
+ "font-size": "25px",
+ "color": "rgb(49, 51, 63)",
+ "text-align": "left",
+ "width": "95%"
+ },
+ "filter": {
+ "background-color": "rgba(0, 0, 0, 0)"
+ }
+ }
+
+ @staticmethod
+ def customize_toast_css_style():
+ # Note padding is top right bottom left
+ st.markdown(
+ """
+            <style>
+                div[data-testid=stToast] {
+                    padding: 20px 10px 40px 10px;
+                    width: 40%;
+                }
+            </style>
+ """, unsafe_allow_html=True
+ )
+
+ @staticmethod
+ def article_markdown_to_html(article_title, article_content):
+        return f"""
+        <html>
+            <head>
+                <meta charset="utf-8">
+                <title>{article_title}</title>
+                <style>
+                    .title {{
+                        text-align: center;
+                    }}
+                </style>
+            </head>
+            <body>
+                <div class="title">
+                    <h1>{article_title.replace('_', ' ')}</h1>
+                </div>
+                <h2>Table of Contents</h2>
+                {DemoTextProcessingHelper.generate_html_toc(article_content)}
+                {markdown.markdown(article_content)}
+            </body>
+        </html>
+        """
+
+
+def _construct_citation_dict_from_search_result(search_results):
+ if search_results is None:
+ return None
+ citation_dict = {}
+ for url, index in search_results['url_to_unified_index'].items():
+ citation_dict[index] = {'url': url,
+ 'title': search_results['url_to_info'][url]['title'],
+ 'snippets': search_results['url_to_info'][url]['snippets']}
+ return citation_dict
+
+
+def _display_main_article_text(article_text, citation_dict, table_content_sidebar):
+ # Post-process the generated article for better display.
+ if "Write the lead section:" in article_text:
+ article_text = article_text[
+ article_text.find("Write the lead section:") + len("Write the lead section:"):]
+    if article_text and article_text[0] == '#':
+ article_text = '\n'.join(article_text.split('\n')[1:])
+ article_text = DemoTextProcessingHelper.add_inline_citation_link(article_text, citation_dict)
+ # '$' needs to be changed to '\$' to avoid being interpreted as LaTeX in st.markdown()
+ article_text = article_text.replace("$", "\\$")
+ stoc.from_markdown(article_text, table_content_sidebar)
+
+
+def _display_references(citation_dict):
+ if citation_dict:
+ reference_list = [f"reference [{i}]" for i in range(1, len(citation_dict) + 1)]
+ selected_key = st.selectbox("Select a reference", reference_list)
+ citation_val = citation_dict[reference_list.index(selected_key) + 1]
+ citation_val['title'] = citation_val['title'].replace("$", "\\$")
+ st.markdown(f"**Title:** {citation_val['title']}")
+ st.markdown(f"**Url:** {citation_val['url']}")
+ snippets = '\n\n'.join(citation_val['snippets']).replace("$", "\\$")
+ st.markdown(f"**Highlights:**\n\n {snippets}")
+ else:
+ st.markdown("**No references available**")
+
+
+def _display_persona_conversations(conversation_log):
+ """
+ Display persona conversation in dialogue UI
+ """
+ # get personas list as (persona_name, persona_description, dialogue turns list) tuple
+ parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(conversation_log)
+ # construct tabs for each persona conversation
+ persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history])
+ for idx, persona_tab in enumerate(persona_tabs):
+ with persona_tab:
+ # show persona description
+ st.info(parsed_conversation_history[idx][1])
+ # show user / agent utterance in dialogue UI
+ for message in parsed_conversation_history[idx][2]:
+ message['content'] = message['content'].replace("$", "\\$")
+ with st.chat_message(message["role"]):
+ if message["role"] == "user":
+ st.markdown(f"**{message['content']}**")
+ else:
+ st.markdown(message["content"])
+
+
+def _display_main_article(selected_article_file_path_dict, show_reference=True, show_conversation=True):
+    article_data = DemoFileIOHelper.assemble_article_data(selected_article_file_path_dict)
+    if article_data is None:
+        st.error("No article text was found for this topic.")
+        return
+
+ with st.container(height=1000, border=True):
+ table_content_sidebar = st.sidebar.expander("**Table of contents**", expanded=True)
+ _display_main_article_text(article_text=article_data.get("article", ""),
+ citation_dict=article_data.get("citations", {}),
+ table_content_sidebar=table_content_sidebar)
+
+ # display reference panel
+ if show_reference and "citations" in article_data:
+ with st.sidebar.expander("**References**", expanded=True):
+ with st.container(height=800, border=False):
+ _display_references(citation_dict=article_data.get("citations", {}))
+
+ # display conversation history
+ if show_conversation and "conversation_log" in article_data:
+ with st.expander(
+ "**STORM** is powered by a knowledge agent that proactively research a given topic by asking good questions coming from different perspectives.\n\n"
+ ":sunglasses: Click here to view the agent's brain**STORM**ing process!"):
+ _display_persona_conversations(conversation_log=article_data.get("conversation_log", {}))
+
+
+def get_demo_dir():
+ return os.path.dirname(os.path.abspath(__file__))
+
+
+def clear_other_page_session_state(page_index: Optional[int]):
+ if page_index is None:
+ keys_to_delete = [key for key in st.session_state if key.startswith("page")]
+ else:
+ keys_to_delete = [key for key in st.session_state if key.startswith("page") and f"page{page_index}" not in key]
+ for key in set(keys_to_delete):
+ del st.session_state[key]
+
+
+def set_storm_runner():
+ current_working_dir = os.path.join(get_demo_dir(), "DEMO_WORKING_DIR")
+ if not os.path.exists(current_working_dir):
+ os.makedirs(current_working_dir)
+
+ # configure STORM runner
+ llm_configs = STORMWikiLMConfigs()
+ llm_configs.init_openai_model(openai_api_key=st.secrets['OPENAI_API_KEY'], openai_type='openai')
+ llm_configs.set_question_asker_lm(OpenAIModel(model='gpt-4-1106-preview', api_key=st.secrets['OPENAI_API_KEY'],
+ api_provider='openai',
+ max_tokens=500, temperature=1.0, top_p=0.9))
+ engine_args = STORMWikiRunnerArguments(
+ output_dir=current_working_dir,
+ max_conv_turn=3,
+ max_perspective=3,
+ search_top_k=3,
+ retrieve_top_k=5
+ )
+
+ rm = YouRM(ydc_api_key=st.secrets['YDC_API_KEY'], k=engine_args.search_top_k)
+
+ runner = STORMWikiRunner(engine_args, llm_configs, rm)
+ st.session_state["runner"] = runner
+
+
+def display_article_page(selected_article_name, selected_article_file_path_dict,
+ show_title=True, show_main_article=True):
+ if show_title:
+        st.markdown(f"<h2 style='text-align: center;'>{selected_article_name.replace('_', ' ')}</h2>",
+                    unsafe_allow_html=True)
+
+ if show_main_article:
+ _display_main_article(selected_article_file_path_dict)
+
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+ def __init__(self, status_container):
+ self.status_container = status_container
+
+ def on_identify_perspective_start(self, **kwargs):
+ self.status_container.info('Start identifying different perspectives for researching the topic.')
+
+ def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
+ perspective_list = "\n- ".join(perspectives)
+ self.status_container.success(f'Finish identifying perspectives. Will now start gathering information'
+ f' from the following perspectives:\n- {perspective_list}')
+
+ def on_information_gathering_start(self, **kwargs):
+ self.status_container.info('Start browsing the Internet.')
+
+ def on_dialogue_turn_end(self, dlg_turn, **kwargs):
+ urls = list(set([r.url for r in dlg_turn.search_results]))
+ for url in urls:
+ self.status_container.markdown(f"""
+                <style>
+                    .small-font {{
+                        font-size: 14px;
+                        margin: 0px;
+                        padding: 0px;
+                    }}
+                </style>
+                <div class="small-font">Finish browsing <a href="{url}" target="_blank">{url}</a>.</div>
+ """, unsafe_allow_html=True)
+
+ def on_information_gathering_end(self, **kwargs):
+ self.status_container.success('Finish collecting information.')
+
+ def on_information_organization_start(self, **kwargs):
+ self.status_container.info('Start organizing information into a hierarchical outline.')
+
+ def on_direct_outline_generation_end(self, outline: str, **kwargs):
+        self.status_container.success('Finish leveraging the internal knowledge of the large language model.')
+
+ def on_outline_refinement_end(self, outline: str, **kwargs):
+        self.status_container.success('Finish leveraging the collected information.')
diff --git a/frontend/demo_light/pages_util/CreateNewArticle.py b/frontend/demo_light/pages_util/CreateNewArticle.py
new file mode 100644
index 00000000..9495ffe5
--- /dev/null
+++ b/frontend/demo_light/pages_util/CreateNewArticle.py
@@ -0,0 +1,102 @@
+import os
+import time
+
+import demo_util
+import streamlit as st
+from demo_util import DemoFileIOHelper, DemoTextProcessingHelper, DemoUIHelper
+
+
+def create_new_article_page():
+ demo_util.clear_other_page_session_state(page_index=3)
+
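+    # "page3_write_article_state" drives a simple state machine:
+    # "not started" -> "initiated" -> "pre_writing" -> "final_writing"
+    # -> "prepare_to_show_result" -> "completed".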
+ if "page3_write_article_state" not in st.session_state:
+ st.session_state["page3_write_article_state"] = "not started"
+
+ if st.session_state["page3_write_article_state"] == "not started":
+
+ _, search_form_column, _ = st.columns([2, 5, 2])
+ with search_form_column:
+ with st.form(key='search_form'):
+ # Text input for the search topic
+ DemoUIHelper.st_markdown_adjust_size(content="Enter the topic you want to learn in depth:",
+ font_size=18)
+ st.session_state["page3_topic"] = st.text_input(label='page3_topic', label_visibility="collapsed")
+ pass_appropriateness_check = True
+
+ # Submit button for the form
+ submit_button = st.form_submit_button(label='Research')
+ # only start new search when button is clicked, not started, or already finished previous one
+ if submit_button and st.session_state["page3_write_article_state"] in ["not started", "show results"]:
+ if not st.session_state["page3_topic"].strip():
+ pass_appropriateness_check = False
+ st.session_state["page3_warning_message"] = "topic could not be empty"
+
+ st.session_state["page3_topic_name_cleaned"] = st.session_state["page3_topic"].replace(
+ ' ', '_').replace('/', '_')
+ if not pass_appropriateness_check:
+ st.session_state["page3_write_article_state"] = "not started"
+ alert = st.warning(st.session_state["page3_warning_message"], icon="⚠️")
+ time.sleep(5)
+ alert.empty()
+ else:
+ st.session_state["page3_write_article_state"] = "initiated"
+
+ if st.session_state["page3_write_article_state"] == "initiated":
+ current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
+ if not os.path.exists(current_working_dir):
+ os.makedirs(current_working_dir)
+
+ if "runner" not in st.session_state:
+ demo_util.set_storm_runner()
+ st.session_state["page3_current_working_dir"] = current_working_dir
+ st.session_state["page3_write_article_state"] = "pre_writing"
+
+ if st.session_state["page3_write_article_state"] == "pre_writing":
+ status = st.status("I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)")
+ st_callback_handler = demo_util.StreamlitCallbackHandler(status)
+ with status:
+ # STORM main gen outline
+ st.session_state["runner"].run(
+ topic=st.session_state["page3_topic"],
+ do_research=True,
+ do_generate_outline=True,
+ do_generate_article=False,
+ do_polish_article=False,
+ callback_handler=st_callback_handler
+ )
+ conversation_log_path = os.path.join(st.session_state["page3_current_working_dir"],
+ st.session_state["page3_topic_name_cleaned"], "conversation_log.json")
+ demo_util._display_persona_conversations(DemoFileIOHelper.read_json_file(conversation_log_path))
+ st.session_state["page3_write_article_state"] = "final_writing"
+ status.update(label="brain**STORM**ing complete!", state="complete")
+
+ if st.session_state["page3_write_article_state"] == "final_writing":
+ # polish final article
+ with st.status(
+ "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)") as status:
+ st.info('Now I will connect the information I found for your reference. (This may take 4-5 minutes.)')
+ st.session_state["runner"].run(topic=st.session_state["page3_topic"], do_research=False,
+ do_generate_outline=False,
+ do_generate_article=True, do_polish_article=True, remove_duplicate=False)
+ # finish the session
+ st.session_state["runner"].post_run()
+
+ # update status bar
+ st.session_state["page3_write_article_state"] = "prepare_to_show_result"
+ status.update(label="information snythesis complete!", state="complete")
+
+ if st.session_state["page3_write_article_state"] == "prepare_to_show_result":
+ _, show_result_col, _ = st.columns([4, 3, 4])
+ with show_result_col:
+ if st.button("show final article"):
+ st.session_state["page3_write_article_state"] = "completed"
+ st.rerun()
+
+ if st.session_state["page3_write_article_state"] == "completed":
+ # display polished article
+ current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict(
+ st.session_state["page3_current_working_dir"])
+ current_article_file_path_dict = current_working_dir_paths[st.session_state["page3_topic_name_cleaned"]]
+ demo_util.display_article_page(selected_article_name=st.session_state["page3_topic_name_cleaned"],
+ selected_article_file_path_dict=current_article_file_path_dict,
+ show_title=True, show_main_article=True)
diff --git a/frontend/demo_light/pages_util/MyArticles.py b/frontend/demo_light/pages_util/MyArticles.py
new file mode 100644
index 00000000..f2c306f2
--- /dev/null
+++ b/frontend/demo_light/pages_util/MyArticles.py
@@ -0,0 +1,89 @@
+import os
+
+import demo_util
+import streamlit as st
+from demo_util import DemoFileIOHelper, DemoUIHelper
+from streamlit_card import card
+
+
+# set page config and display title
+def my_articles_page():
+ with st.sidebar:
+ _, return_button_col = st.columns([2, 5])
+ with return_button_col:
+ if st.button("Select another article", disabled="page2_selected_my_article" not in st.session_state):
+ if "page2_selected_my_article" in st.session_state:
+ del st.session_state["page2_selected_my_article"]
+ st.rerun()
+
+ # sync my articles
+ if "page2_user_articles_file_path_dict" not in st.session_state:
+ local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
+ os.makedirs(local_dir, exist_ok=True)
+ st.session_state["page2_user_articles_file_path_dict"] = DemoFileIOHelper.read_structure_to_dict(local_dir)
+
+ # if no feature demo selected, display all featured articles as info cards
+ def article_card_setup(column_to_add, card_title, article_name):
+ with column_to_add:
+ cleaned_article_title = article_name.replace("_", " ")
+ hasClicked = card(title=" / ".join(card_title),
+                              text=cleaned_article_title,
+ image=DemoFileIOHelper.read_image_as_base64(
+ os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
+                              styles=DemoUIHelper.get_article_card_UI_style(border_color="#9AD8E1"))
+ if hasClicked:
+ st.session_state["page2_selected_my_article"] = article_name
+ st.rerun()
+
+ if "page2_selected_my_article" not in st.session_state:
+ # display article cards
+ my_article_columns = st.columns(3)
+ if len(st.session_state["page2_user_articles_file_path_dict"]) > 0:
+ # get article names
+ article_names = sorted(list(st.session_state["page2_user_articles_file_path_dict"].keys()))
+ # configure pagination
+ pagination = st.container()
+ bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1]
+ with bottom_menu[2]:
+ batch_size = st.selectbox("Page Size", options=[24, 48, 72])
+ with bottom_menu[1]:
+                # Ceiling division so a partially filled last page still counts.
+                total_pages = max(
+                    (len(article_names) + batch_size - 1) // batch_size, 1
+                )
+ current_page = st.number_input(
+ "Page", min_value=1, max_value=total_pages, step=1
+ )
+ with bottom_menu[0]:
+ st.markdown(f"Page **{current_page}** of **{total_pages}** ")
+ # show article cards
+ with pagination:
+ my_article_count = 0
+ start_index = (current_page - 1) * batch_size
+ end_index = min(current_page * batch_size, len(article_names))
+ for article_name in article_names[start_index: end_index]:
+ column_to_add = my_article_columns[my_article_count % 3]
+ my_article_count += 1
+ article_card_setup(column_to_add=column_to_add,
+ card_title=["My Article"],
+ article_name=article_name)
+ else:
+ with my_article_columns[0]:
+ hasClicked = card(title="Get started",
+ text="Start your first research!",
+ image=DemoFileIOHelper.read_image_as_base64(
+ os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")),
+ styles=DemoUIHelper.get_article_card_UI_style())
+ if hasClicked:
+ st.session_state.selected_page = 1
+ st.session_state["manual_selection_override"] = True
+ st.session_state["rerun_requested"] = True
+ st.rerun()
+ else:
+ selected_article_name = st.session_state["page2_selected_my_article"]
+ selected_article_file_path_dict = st.session_state["page2_user_articles_file_path_dict"][selected_article_name]
+
+    demo_util.display_article_page(selected_article_name=selected_article_name,
+                                   selected_article_file_path_dict=selected_article_file_path_dict,
+                                   show_title=True, show_main_article=True)
diff --git a/frontend/demo_light/requirements.txt b/frontend/demo_light/requirements.txt
new file mode 100644
index 00000000..a5ba2884
--- /dev/null
+++ b/frontend/demo_light/requirements.txt
@@ -0,0 +1,10 @@
+streamlit==1.31.1
+streamlit-card
+markdown
+unidecode
+extra-streamlit-components==0.1.60
+streamlit_extras
+deprecation==2.1.0
+st-pages==0.4.5
+streamlit-float
+streamlit-option-menu
\ No newline at end of file
diff --git a/frontend/demo_light/stoc.py b/frontend/demo_light/stoc.py
new file mode 100644
index 00000000..7bd4402b
--- /dev/null
+++ b/frontend/demo_light/stoc.py
@@ -0,0 +1,131 @@
+"""https://github.com/arnaudmiribel/stoc"""
+
+import re
+
+import streamlit as st
+import unidecode
+
+DISABLE_LINK_CSS = """
+<style>
+    a.toc {
+        color: inherit;
+        text-decoration: none;  /* hide the anchor underline */
+    }
+</style>
+"""
+
+
+class stoc:
+ def __init__(self):
+ self.toc_items = list()
+
+ def h1(self, text: str, write: bool = True):
+ if write:
+ st.write(f"# {text}")
+ self.toc_items.append(("h1", text))
+
+ def h2(self, text: str, write: bool = True):
+ if write:
+ st.write(f"## {text}")
+ self.toc_items.append(("h2", text))
+
+ def h3(self, text: str, write: bool = True):
+ if write:
+ st.write(f"### {text}")
+ self.toc_items.append(("h3", text))
+
+ def toc(self, expander):
+ st.write(DISABLE_LINK_CSS, unsafe_allow_html=True)
+ # st.sidebar.caption("Table of contents")
+ if expander is None:
+ expander = st.sidebar.expander("**Table of contents**", expanded=True)
+ with expander:
+ with st.container(height=600, border=False):
+ markdown_toc = ""
+ for title_size, title in self.toc_items:
+ h = int(title_size.replace("h", ""))
+ markdown_toc += (
+ " " * 2 * h
+ + "- "
+                        + f'<a href="#{normalize(title)}" class="toc">{title}</a> \n'
+ )
+ # st.sidebar.write(markdown_toc, unsafe_allow_html=True)
+ st.write(markdown_toc, unsafe_allow_html=True)
+
+ @classmethod
+ def get_toc(cls, markdown_text: str, topic=""):
+ def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading):
+ lines = markdown_text.splitlines()
+ # Increase the depth of each heading by adding an extra '#'
+ increased_depth_lines = ['#' + line if line.startswith('#') else line for line in lines]
+ # Add the new top-level heading at the beginning
+ increased_depth_lines.insert(0, f"# {new_top_heading}")
+ # Re-join the modified lines back into a single string
+ modified_text = '\n'.join(increased_depth_lines)
+ return modified_text
+
+ if topic:
+ markdown_text = increase_heading_depth_and_add_top_heading(markdown_text, topic)
+ toc = []
+ for line in markdown_text.splitlines():
+ if line.startswith('#'):
+ # Remove the '#' characters and strip leading/trailing spaces
+ heading_text = line.lstrip('#').strip()
+ # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters)
+ slug = re.sub(r'[^a-zA-Z0-9\s-]', '', heading_text).lower().replace(' ', '-')
+ # Determine heading level for indentation
+ level = line.count('#') - 1
+ # Add to the table of contents
+ toc.append(' ' * level + f'- [{heading_text}](#{slug})')
+ return '\n'.join(toc)
+
+ @classmethod
+ def from_markdown(cls, text: str, expander=None):
+ self = cls()
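+        # Collect headings for the TOC without re-emitting them (write=False);
+        # the full text is rendered once by st.write(text) below.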
+ for line in text.splitlines():
+ if line.startswith("###"):
+ self.h3(line[3:], write=False)
+ elif line.startswith("##"):
+ self.h2(line[2:], write=False)
+ elif line.startswith("#"):
+ self.h1(line[1:], write=False)
+ # customize markdown font size
+ custom_css = """
+        <style>
+            /* Adjust font sizes for rendered markdown */
+            h1 { font-size: 28px; }
+            h2 { font-size: 24px; }
+            h3 { font-size: 22px; }
+            h4 { font-size: 20px; }
+            h5 { font-size: 18px; }
+            p { font-size: 18px; }
+        </style>
+ """
+ st.markdown(custom_css, unsafe_allow_html=True)
+
+ st.write(text)
+ self.toc(expander=expander)
+
+
+def normalize(s):
+ """
+ Normalize titles as valid HTML ids for anchors
+ >>> normalize("it's a test to spot how Things happ3n héhé")
+ "it-s-a-test-to-spot-how-things-happ3n-h-h"
+ """
+
+ # Replace accents with "-"
+ s_wo_accents = unidecode.unidecode(s)
+    accents = [char for char in s if char not in s_wo_accents]
+ for accent in accents:
+ s = s.replace(accent, "-")
+
+ # Lowercase
+ s = s.lower()
+
+ # Keep only alphanum and remove "-" suffix if existing
+ normalized = (
+ "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower()
+ )
+
+ return normalized
diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py
new file mode 100644
index 00000000..9a0ae663
--- /dev/null
+++ b/frontend/demo_light/storm.py
@@ -0,0 +1,64 @@
+import os
+import sys
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+wiki_root_dir = os.path.dirname(os.path.dirname(script_dir))
+
+sys.path.append(os.path.normpath(os.path.join(script_dir, '../../src/storm_wiki')))
+sys.path.append(os.path.normpath(os.path.join(script_dir, '../../src')))
+
+import demo_util
+import streamlit as st
+from pages_util import MyArticles, CreateNewArticle
+from streamlit_float import *
+from streamlit_option_menu import option_menu
+
+
+def main():
+ st.set_page_config(layout='wide')
+
+ if "first_run" not in st.session_state:
+ st.session_state['first_run'] = True
+
+ # set api keys from secrets
+ if st.session_state['first_run']:
+ for key, value in st.secrets.items():
+            if isinstance(value, str):
+ os.environ[key] = value
+
+ # initialize session_state
+ if "selected_article_index" not in st.session_state:
+ st.session_state["selected_article_index"] = 0
+ if "selected_page" not in st.session_state:
+ st.session_state["selected_page"] = 0
+ if st.session_state.get("rerun_requested", False):
+ st.session_state["rerun_requested"] = False
+ st.rerun()
+
+    st.write('<style>div.block-container{padding-top: 2rem;}</style>', unsafe_allow_html=True)
+ menu_container = st.container()
+ with menu_container:
+ pages = ["My Articles", "Create New Article"]
+ menu_selection = option_menu(None, pages,
+ icons=['house', 'search'],
+ menu_icon="cast", default_index=0, orientation="horizontal",
+ manual_select=st.session_state.selected_page,
+ styles={
+ "container": {"padding": "0.2rem 0", "background-color": "#22222200"},
+ },
+ key='menu_selection')
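+    # "manual_selection_override" is set elsewhere (e.g., by the "Get started" card
+    # in MyArticles) to force a page switch on the next rerun, bypassing the menu widget.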
+ if st.session_state.get("manual_selection_override", False):
+ menu_selection = pages[st.session_state["selected_page"]]
+ st.session_state["manual_selection_override"] = False
+ st.session_state["selected_page"] = None
+
+ if menu_selection == "My Articles":
+ demo_util.clear_other_page_session_state(page_index=2)
+ MyArticles.my_articles_page()
+ elif menu_selection == "Create New Article":
+ demo_util.clear_other_page_session_state(page_index=3)
+ CreateNewArticle.create_new_article_page()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/storm_wiki/engine.py b/src/storm_wiki/engine.py
index 66fa4b05..4191d35a 100644
--- a/src/storm_wiki/engine.py
+++ b/src/storm_wiki/engine.py
@@ -48,24 +48,9 @@ def init_openai_model(
'api_provider': openai_type,
'temperature': temperature,
'top_p': top_p,
- 'api_base': None,
- 'api_version': None,
+ 'api_base': None
}
- if openai_type and openai_type == 'azure':
- openai_kwargs['api_base'] = api_base
- openai_kwargs['api_version'] = api_version
-
- self.conv_simulator_lm = OpenAIModel(model='gpt-35-turbo-instruct', engine='gpt-35-turbo-instruct',
- max_tokens=500, **openai_kwargs)
- self.question_asker_lm = OpenAIModel(model='gpt-35-turbo', engine='gpt-35-turbo',
- max_tokens=500, **openai_kwargs)
- self.outline_gen_lm = OpenAIModel(model='gpt-4', engine='gpt-4',
- max_tokens=400, **openai_kwargs)
- self.article_gen_lm = OpenAIModel(model='gpt-4', engine='gpt-4',
- max_tokens=700, **openai_kwargs)
- self.article_polish_lm = OpenAIModel(model='gpt-4-32k', engine='gpt-4-32k',
- max_tokens=4000, **openai_kwargs)
- elif openai_type and openai_type == 'openai':
+    if openai_type == 'openai':
self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct',
max_tokens=500, **openai_kwargs)
self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo',
@@ -73,9 +58,9 @@ def init_openai_model(
# 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.)
self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview',
max_tokens=400, **openai_kwargs)
- self.article_gen_lm = OpenAIModel(model='gpt-4-0125-preview',
+ self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13',
max_tokens=700, **openai_kwargs)
- self.article_polish_lm = OpenAIModel(model='gpt-4-0125-preview',
+ self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13',
max_tokens=4000, **openai_kwargs)
else:
logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.')