From 84f0dc336098f3fb8ff11ad0fe7565a8ddefe5c9 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Tue, 17 Oct 2023 13:51:54 -0600 Subject: [PATCH 1/6] added ignore error kwarg callable to bypass some errors --- elm/base.py | 41 ++++++++++++++++++++++++++++++++++++----- elm/pdf.py | 37 ++++++++++++++++++++++--------------- elm/summary.py | 13 +++++++++++-- 3 files changed, 69 insertions(+), 22 deletions(-) diff --git a/elm/base.py b/elm/base.py index 435a9d5b..b97f47c1 100644 --- a/elm/base.py +++ b/elm/base.py @@ -96,7 +96,7 @@ async def call_api(url, headers, request_json): return out async def call_api_async(self, url, headers, all_request_jsons, - rate_limit=40e3): + ignore_error=None, rate_limit=40e3): """Use GPT to clean raw pdf text in parallel calls to the OpenAI API. NOTE: you need to call this using the await command in ipython or @@ -119,6 +119,10 @@ async def call_api_async(self, url, headers, all_request_jsons, "messages": [{"role": "system", "content": "You do this..."}, {"role": "user", "content": "Do this: {}"}], "temperature": 0.0} + ignore_error : None | callable + Optional callable to parse API error string. If the callable + returns True, the error will be ignored, the API call will not be + tried again, and the output will be an empty string. rate_limit : float OpenAI API rate limit (tokens / minute). Note that the gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large @@ -132,6 +136,7 @@ async def call_api_async(self, url, headers, all_request_jsons, corresponding message in the all_request_jsons input. """ self.api_queue = ApiQueue(url, headers, all_request_jsons, + ignore_error=ignore_error, rate_limit=rate_limit) out = await self.api_queue.run() return out @@ -207,7 +212,8 @@ def generic_query(self, query, model_role=None, temperature=0): return response async def generic_async_query(self, queries, model_role=None, - temperature=0, rate_limit=40e3): + temperature=0, ignore_error=None, + rate_limit=40e3): """Run a number of generic single queries asynchronously (not conversational) @@ -225,6 +231,10 @@ async def generic_async_query(self, queries, model_role=None, GPT model temperature, a measure of response entropy from 0 to 1. 0 is more reliable and nearly deterministic; 1 will give the model more creative freedom and may not return as factual of results. + ignore_error : None | callable + Optional callable to parse API error string. If the callable + returns True, the error will be ignored, the API call will not be + tried again, and the output will be an empty string. rate_limit : float OpenAI API rate limit (tokens / minute). Note that the gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large @@ -247,6 +257,7 @@ async def generic_async_query(self, queries, model_role=None, all_request_jsons.append(req) self.api_queue = ApiQueue(self.URL, self.HEADERS, all_request_jsons, + ignore_error=ignore_error, rate_limit=rate_limit) out = await self.api_queue.run() @@ -324,7 +335,8 @@ def count_tokens(text, model): class ApiQueue: """Class to manage the parallel API queue and submission""" - def __init__(self, url, headers, request_jsons, rate_limit=40e3): + def __init__(self, url, headers, request_jsons, ignore_error=None, + rate_limit=40e3): """ Parameters ---------- @@ -343,6 +355,10 @@ def __init__(self, url, headers, request_jsons, rate_limit=40e3): "messages": [{"role": "system", "content": "You do this..."}, {"role": "user", "content": "Do this: {}"}], "temperature": 0.0} + ignore_error : None | callable + Optional callable to parse API error string. 
If the callable + returns True, the error will be ignored, the API call will not be + tried again, and the output will be an empty string. rate_limit : float OpenAI API rate limit (tokens / minute). Note that the gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large @@ -353,11 +369,13 @@ def __init__(self, url, headers, request_jsons, rate_limit=40e3): self.url = url self.headers = headers self.request_jsons = request_jsons + self.ignore_error = ignore_error self.rate_limit = rate_limit self.api_jobs = {} self.todo = [True] * len(self) self.out = [None] * len(self) + self.errors = [None] * len(self) def __len__(self): """Number of API calls to submit""" @@ -401,8 +419,21 @@ async def collect_jobs(self): task_out = await self.api_jobs[i] if 'error' in task_out: - logger.error('Received API error for task #{}: {}' - .format(i + 1, task_out)) + msg = ('Received API error for task #{0} ' + '(see `ApiQueue.errors[{1}]` and ' + '`ApiQueue.request_jsons[{1}]` for more details). ' + 'Error message: {2}'.format(i + 1, i, task_out)) + self.errors[i] = 'Error: {}'.format(task_out) + if (self.ignore_error is not None + and self.ignore_error(str(task_out))): + msg += ' Ignoring error and moving on.' + dummy = {'choices': [{'message': {'content': ''}}]} + self.out[i] = dummy + self.todo[i] = False + complete = len(self) - sum(self.todo) + else: + msg += ' Retrying query.' + logger.error(msg) else: self.out[i] = task_out self.todo[i] = False diff --git a/elm/pdf.py b/elm/pdf.py index 4f9e60b5..0c13d64b 100644 --- a/elm/pdf.py +++ b/elm/pdf.py @@ -16,6 +16,18 @@ class PDFtoTXT(ApiBase): """Class to parse text from a PDF document.""" + MODEL_ROLE = ('You clean up poorly formatted text ' + 'extracted from PDF documents.') + """High level model role.""" + + MODEL_INSTRUCTION = ('Text extracted from a PDF: ' + '\n"""\n{}\n"""\n\n' + 'The text above was extracted from a PDF document. ' + 'Can you make it nicely formatted? ' + 'Please only return the formatted text ' + 'without comments or added information.') + """Instructions to the model with python format braces for pdf text""" + def __init__(self, fp, page_range=None, model=None): """ Parameters @@ -68,14 +80,11 @@ def load_pdf(self, page_range): .format(i + 1 + page_range.start, len(pdf.pages))) else: out.append(page_text) - logger.debug('Loaded page {} out of {}' - .format(i + 1 + page_range.start, len(pdf.pages))) logger.info('Finished loading PDF.') return out - @staticmethod - def make_gpt_messages(pdf_raw_text): + def make_gpt_messages(self, pdf_raw_text): """Make the chat completion messages list for input to GPT Parameters @@ -91,16 +100,9 @@ def make_gpt_messages(pdf_raw_text): [{"role": "system", "content": "You do this..."}, {"role": "user", "content": "Please do this: {}"}] """ - query = ('Text extracted from a PDF: ' - '\"\"\"\n{}\"\"\"\n\n' - 'The text above was extracted from a PDF document. ' - 'Can you make it nicely formatted? ' - 'Please only return the formatted text, nothing else.' 
- .format(pdf_raw_text)) - - role_str = ('You clean up poorly formatted text ' - 'extracted from PDF documents.') - messages = [{"role": "system", "content": role_str}, + + query = self.MODEL_INSTRUCTION.format(pdf_raw_text) + messages = [{"role": "system", "content": self.MODEL_ROLE}, {"role": "user", "content": query}] return messages @@ -147,7 +149,7 @@ def clean_txt(self): return clean_pages - async def clean_txt_async(self, rate_limit=40e3): + async def clean_txt_async(self, ignore_error=None, rate_limit=40e3): """Use GPT to clean raw pdf text in parallel calls to the OpenAI API. NOTE: you need to call this using the await command in ipython or @@ -155,6 +157,10 @@ async def clean_txt_async(self, rate_limit=40e3): Parameters ---------- + ignore_error : None | callable + Optional callable to parse API error string. If the callable + returns True, the error will be ignored, the API call will not be + tried again, and the output will be an empty string. rate_limit : float OpenAI API rate limit (tokens / minute). Note that the gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large @@ -178,6 +184,7 @@ async def clean_txt_async(self, rate_limit=40e3): clean_pages = await self.call_api_async(self.URL, self.HEADERS, all_request_jsons, + ignore_error=ignore_error, rate_limit=rate_limit) for i, page in enumerate(clean_pages): diff --git a/elm/summary.py b/elm/summary.py index bbef4274..bc8a7a84 100644 --- a/elm/summary.py +++ b/elm/summary.py @@ -99,6 +99,8 @@ def run(self, temperature=0, fancy_combine=True): Summary of text. """ + logger.info('Summarizing {} text chunks in serial...' + .format(len(self.text_chunks))) summary = '' for i, chunk in enumerate(self.text_chunks): @@ -115,10 +117,12 @@ def run(self, temperature=0, fancy_combine=True): if fancy_combine: summary = self.combine(summary) + logger.info('Finished all summaries.') + return summary - async def run_async(self, temperature=0, rate_limit=40e3, - fancy_combine=True): + async def run_async(self, temperature=0, ignore_error=None, + rate_limit=40e3, fancy_combine=True): """Run text summary asynchronously for all text chunks NOTE: you need to call this using the await command in ipython or @@ -130,6 +134,10 @@ async def run_async(self, temperature=0, rate_limit=40e3, GPT model temperature, a measure of response entropy from 0 to 1. 0 is more reliable and nearly deterministic; 1 will give the model more creative freedom and may not return as factual of results. + ignore_error : None | callable + Optional callable to parse API error string. If the callable + returns True, the error will be ignored, the API call will not be + tried again, and the output will be an empty string. rate_limit : float OpenAI API rate limit (tokens / minute). 
Note that the gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large @@ -157,6 +165,7 @@ async def run_async(self, temperature=0, rate_limit=40e3, summaries = await self.generic_async_query(queries, model_role=self.MODEL_ROLE, temperature=temperature, + ignore_error=ignore_error, rate_limit=rate_limit) self.summary_chunks = summaries From e4c515b1aa4caef11faef55f0c147ab48182d9bf Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 18 Oct 2023 13:19:50 -0600 Subject: [PATCH 2/6] added max retry counter --- elm/base.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/elm/base.py b/elm/base.py index b97f47c1..39ca1ab4 100644 --- a/elm/base.py +++ b/elm/base.py @@ -3,6 +3,7 @@ ELM abstract class for API calls """ from abc import ABC +import numpy as np import asyncio import aiohttp import openai @@ -336,7 +337,7 @@ class ApiQueue: """Class to manage the parallel API queue and submission""" def __init__(self, url, headers, request_jsons, ignore_error=None, - rate_limit=40e3): + rate_limit=40e3, max_retries=5): """ Parameters ---------- @@ -364,6 +365,9 @@ def __init__(self, url, headers, request_jsons, ignore_error=None, gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large factor of safety (~1/2) because we can only count the tokens on the input side and assume the output is about the same count. + max_retries : int + Number of times to retry an API call with an error response before + raising an error. """ self.url = url @@ -371,11 +375,13 @@ def __init__(self, url, headers, request_jsons, ignore_error=None, self.request_jsons = request_jsons self.ignore_error = ignore_error self.rate_limit = rate_limit + self.max_retries = max_retries self.api_jobs = {} self.todo = [True] * len(self) self.out = [None] * len(self) self.errors = [None] * len(self) + self.tries = np.zeros(len(self)) def __len__(self): """Number of API calls to submit""" @@ -398,12 +404,14 @@ def submit_jobs(self): self.headers, request)) self.api_jobs[i] = task + self.tries[i] += 1 logger.debug('Submitted {} out of {}, ' 'token count is at {} ' - '(rate limit is {})' + '(rate limit is {}). ' + 'Max attempts for a job is {}' .format(i + 1, len(self), token_count, - self.rate_limit)) + self.rate_limit, int(self.tries.max()))) elif token_count >= self.rate_limit: token_count = 0 @@ -424,6 +432,7 @@ async def collect_jobs(self): '`ApiQueue.request_jsons[{1}]` for more details). ' 'Error message: {2}'.format(i + 1, i, task_out)) self.errors[i] = 'Error: {}'.format(task_out) + if (self.ignore_error is not None and self.ignore_error(str(task_out))): msg += ' Ignoring error and moving on.' @@ -433,7 +442,9 @@ async def collect_jobs(self): complete = len(self) - sum(self.todo) else: msg += ' Retrying query.' + logger.error(msg) + else: self.out[i] = task_out self.todo[i] = False @@ -454,8 +465,21 @@ async def run(self): logger.debug('Submitting async API calls...') + self.api_jobs = {} + self.todo = [True] * len(self) + self.out = [None] * len(self) + self.errors = [None] * len(self) + self.tries = np.zeros(len(self)) + while any(self.todo): self.submit_jobs() await self.collect_jobs() + if any(self.tries > self.max_retries): + msg = (f'Hit {self.max_retries} retries on API queries. ' + 'Stopping. 
See `ApiQueue.errors` for more ' + 'details on error response') + logger.error(msg) + raise RuntimeError(msg) + return self.out From 27fa2dc3a3cbf5959066fb92364fcd548c6fc77b Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 18 Oct 2023 13:21:30 -0600 Subject: [PATCH 3/6] set comprehension --- elm/pdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elm/pdf.py b/elm/pdf.py index 0c13d64b..63f22264 100644 --- a/elm/pdf.py +++ b/elm/pdf.py @@ -219,8 +219,8 @@ def replace_chars_for_clean(text): raw_words = replace_chars_for_clean(raw).split(' ') clean_words = replace_chars_for_clean(clean).split(' ') - raw_words = set([x for x in raw_words if len(x) > 2]) - clean_words = set([x for x in clean_words if len(x) > 2]) + raw_words = {x for x in raw_words if len(x) > 2} + clean_words = {x for x in clean_words if len(x) > 2} isin = sum(x in clean_words for x in raw_words) From 3fb234e7cae9477f19a62af50d3f6427328354cf Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Oct 2023 15:30:08 -0600 Subject: [PATCH 4/6] added chat feature to ewiz --- elm/base.py | 7 ++++++- elm/wizard.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/elm/base.py b/elm/base.py index 39ca1ab4..e63c21ec 100644 --- a/elm/base.py +++ b/elm/base.py @@ -49,8 +49,13 @@ def __init__(self, model=None): cls.DEFAULT_MODEL """ self.model = model or self.DEFAULT_MODEL - self.chat_messages = [{"role": "system", "content": self.MODEL_ROLE}] self.api_queue = None + self.chat_messages = [] + self.clear() + + def clear(self): + """Clear chat""" + self.chat_messages = [{"role": "system", "content": self.MODEL_ROLE}] @staticmethod async def call_api(url, headers, request_json): diff --git a/elm/wizard.py b/elm/wizard.py index b368a634..f19d3699 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -126,7 +126,14 @@ def rank_strings(self, query, top_n=100): return strings, scores, best - def engineer_query(self, query, token_budget=None, new_info_threshold=0.7): + def _make_convo_query(self): + messages = [f"{msg['role'].upper()}: {msg['content']}" + for msg in self.messages] + messages = '\n\n'.join(messages) + return messages + + def engineer_query(self, query, token_budget=None, new_info_threshold=0.7, + conversational=False): """Engineer a query for GPT using the corpus of information Parameters @@ -139,6 +146,9 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7): New text added to the engineered query must contain at least this much new information. This helps prevent (for example) the table of contents being added multiple times. + conversational : bool + Flag to ask query with conversation history. Call + EnergyWizard.clear() to reset the chat. Returns ------- @@ -150,6 +160,9 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7): returned here """ + if conversational: + query = self._make_convo_query() + token_budget = token_budget or self.token_budget strings, _, idx = self.rank_strings(query) @@ -199,8 +212,14 @@ def make_ref_list(self, idx): return ref_list - def ask(self, query, debug=True, stream=True, temperature=0, - token_budget=None, new_info_threshold=0.7, print_references=False): + def ask(self, query, + debug=True, + stream=True, + temperature=0, + conversational=False, + token_budget=None, + new_info_threshold=0.7, + print_references=False): """Answers a query using GPT and a dataframe of relevant texts and embeddings. 
@@ -214,6 +233,9 @@ def ask(self, query, debug=True, stream=True, temperature=0, GPT model temperature, a measure of response entropy from 0 to 1. 0 is more reliable and nearly deterministic; 1 will give the model more creative freedom and may not return as factual of results. + conversational : bool + Flag to ask query with conversation history. Call + EnergyWizard.clear() to reset the chat. token_budget : int Option to override the class init token budget. new_info_threshold : float @@ -236,8 +258,10 @@ def ask(self, query, debug=True, stream=True, temperature=0, engineered prompt is returned here """ + self.chat_messages.append({"role": "user", "content": query}) out = self.engineer_query(query, token_budget=token_budget, - new_info_threshold=new_info_threshold) + new_info_threshold=new_info_threshold, + conversational=conversational) query, references = out messages = [{"role": "system", "content": self.MODEL_ROLE}, @@ -268,6 +292,9 @@ def ask(self, query, debug=True, stream=True, temperature=0, 'support its answer:') print(' - ' + '\n - '.join(references)) + self.chat_messages.append({'role': 'assistant', + 'content': response_message}) + if debug: return response_message, query, references else: From 7525dd086abf8f80929d034ddc159b85e9467a8e Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Oct 2023 15:34:44 -0600 Subject: [PATCH 5/6] bug fix on attr --- elm/wizard.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/elm/wizard.py b/elm/wizard.py index f19d3699..2c661890 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -127,10 +127,10 @@ def rank_strings(self, query, top_n=100): return strings, scores, best def _make_convo_query(self): - messages = [f"{msg['role'].upper()}: {msg['content']}" - for msg in self.messages] - messages = '\n\n'.join(messages) - return messages + query = [f"{msg['role'].upper()}: {msg['content']}" + for msg in self.chat_messages] + query = '\n\n'.join(query) + return query def engineer_query(self, query, token_budget=None, new_info_threshold=0.7, conversational=False): From 23eae3f495deb3a7f532c201f8303b03c6e7cf70 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Fri, 20 Oct 2023 08:49:06 -0600 Subject: [PATCH 6/6] cleaned up convo feature - breaking change on wizard.ask() -> wizard.chat() and .messages attribute --- elm/base.py | 26 ++++++++++++++----- elm/tree.py | 7 ++---- elm/wizard.py | 59 ++++++++++++++++++++++---------------------- tests/test_wizard.py | 56 +++++++++++++++++++++++++++++++++++------ 4 files changed, 100 insertions(+), 48 deletions(-) diff --git a/elm/base.py b/elm/base.py index e63c21ec..8dd12195 100644 --- a/elm/base.py +++ b/elm/base.py @@ -50,12 +50,26 @@ def __init__(self, model=None): """ self.model = model or self.DEFAULT_MODEL self.api_queue = None - self.chat_messages = [] + self.messages = [] self.clear() + @property + def all_messages_txt(self): + """Get a string printout of the full conversation with the LLM + + Returns + ------- + str + """ + messages = [f"{msg['role'].upper()}: {msg['content']}" + for msg in self.messages] + messages = '\n\n'.join(messages) + return messages + def clear(self): - """Clear chat""" - self.chat_messages = [{"role": "system", "content": self.MODEL_ROLE}] + """Clear chat history and reduce messages to just the initial model + role message.""" + self.messages = [{"role": "system", "content": self.MODEL_ROLE}] @staticmethod async def call_api(url, headers, request_json): @@ -166,10 +180,10 @@ def chat(self, query, temperature=0): Model response """ - 
self.chat_messages.append({"role": "user", "content": query}) + self.messages.append({"role": "user", "content": query}) kwargs = dict(model=self.model, - messages=self.chat_messages, + messages=self.messages, temperature=temperature, stream=False) if 'azure' in str(openai.api_type).lower(): @@ -177,7 +191,7 @@ def chat(self, query, temperature=0): response = openai.ChatCompletion.create(**kwargs) response = response["choices"][0]["message"]["content"] - self.chat_messages.append({'role': 'assistant', 'content': response}) + self.messages.append({'role': 'assistant', 'content': response}) return response diff --git a/elm/tree.py b/elm/tree.py index 5a156853..8946ac08 100644 --- a/elm/tree.py +++ b/elm/tree.py @@ -80,7 +80,7 @@ def messages(self): ------- list """ - return self.api.chat_messages + return self.api.messages @property def all_messages_txt(self): @@ -90,10 +90,7 @@ def all_messages_txt(self): ------- str """ - messages = [f"{msg['role'].upper()}: {msg['content']}" - for msg in self.messages] - messages = '\n\n'.join(messages) - return messages + return self.api.all_messages_txt @property def history(self): diff --git a/elm/wizard.py b/elm/wizard.py index 2c661890..a0121ec5 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -126,14 +126,8 @@ def rank_strings(self, query, top_n=100): return strings, scores, best - def _make_convo_query(self): - query = [f"{msg['role'].upper()}: {msg['content']}" - for msg in self.chat_messages] - query = '\n\n'.join(query) - return query - def engineer_query(self, query, token_budget=None, new_info_threshold=0.7, - conversational=False): + convo=False): """Engineer a query for GPT using the corpus of information Parameters @@ -146,10 +140,10 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7, New text added to the engineered query must contain at least this much new information. This helps prevent (for example) the table of contents being added multiple times. - conversational : bool - Flag to ask query with conversation history. Call - EnergyWizard.clear() to reset the chat. - + convo : bool + Flag to perform semantic search with full conversation history + (True) or just the single query (False). Call EnergyWizard.clear() + to reset the chat history. Returns ------- message : str @@ -160,8 +154,13 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7, returned here """ - if conversational: - query = self._make_convo_query() + self.messages.append({"role": "user", "content": query}) + + if convo: + # [1:] to not include the system role in the semantic search + query = [f"{msg['role'].upper()}: {msg['content']}" + for msg in self.messages[1:]] + query = '\n\n'.join(query) token_budget = token_budget or self.token_budget @@ -212,16 +211,16 @@ def make_ref_list(self, idx): return ref_list - def ask(self, query, - debug=True, - stream=True, - temperature=0, - conversational=False, - token_budget=None, - new_info_threshold=0.7, - print_references=False): - """Answers a query using GPT and a dataframe of relevant texts and - embeddings. + def chat(self, query, + debug=True, + stream=True, + temperature=0, + convo=False, + token_budget=None, + new_info_threshold=0.7, + print_references=False): + """Answers a query by doing a semantic search of relevant text with + embeddings and then sending engineered query to the LLM. Parameters ---------- @@ -233,9 +232,10 @@ def ask(self, query, GPT model temperature, a measure of response entropy from 0 to 1. 
0 is more reliable and nearly deterministic; 1 will give the model more creative freedom and may not return as factual of results. - conversational : bool - Flag to ask query with conversation history. Call - EnergyWizard.clear() to reset the chat. + convo : bool + Flag to perform semantic search with full conversation history + (True) or just the single query (False). Call EnergyWizard.clear() + to reset the chat history. token_budget : int Option to override the class init token budget. new_info_threshold : float @@ -258,10 +258,9 @@ def ask(self, query, engineered prompt is returned here """ - self.chat_messages.append({"role": "user", "content": query}) out = self.engineer_query(query, token_budget=token_budget, new_info_threshold=new_info_threshold, - conversational=conversational) + convo=convo) query, references = out messages = [{"role": "system", "content": self.MODEL_ROLE}, @@ -292,8 +291,8 @@ def ask(self, query, 'support its answer:') print(' - ' + '\n - '.join(references)) - self.chat_messages.append({'role': 'assistant', - 'content': response_message}) + self.messages.append({'role': 'assistant', + 'content': response_message}) if debug: return response_message, query, references diff --git a/tests/test_wizard.py b/tests/test_wizard.py index 751cd259..2b475656 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -40,12 +40,8 @@ def create(*args, **kwargs): # pylint: disable=unused-argument return response -def test_chunk_and_embed(mocker): - """Simple text to embedding test - - Note that embedding api is mocked here and not actually tested. - """ - +def make_corpus(mocker): + """Make a text corpus with embeddings for the wizard.""" mocker.patch.object(elm.embed.ChunkAndEmbed, "call_api", MockClass.call) mocker.patch.object(elm.wizard.EnergyWizard, "get_embedding", MockClass.get_embedding) @@ -58,14 +54,60 @@ def test_chunk_and_embed(mocker): for i, emb in enumerate(embeddings): corpus.append({'text': ce0.text_chunks[i], 'embedding': emb, 'ref': 'source0'}) + return corpus + +def test_chunk_and_embed(mocker): + """Simple text to embedding test + + Note that embedding api is mocked here and not actually tested. + """ + + corpus = make_corpus(mocker) wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000, ref_col='ref') + question = 'What time is it?' - out = wizard.ask(question, debug=True, stream=False, print_references=True) + out = wizard.chat(question, debug=True, stream=False, + print_references=True) msg, query, ref = out assert msg == 'hello!' assert query.startswith(EnergyWizard.MODEL_INSTRUCTION) assert query.endswith(question) assert 'source0' in ref + + +def test_convo_query(mocker): + """Query with multiple messages + + Note that embedding api is mocked here and not actually tested. + """ + + corpus = make_corpus(mocker) + wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000, + ref_col='ref') + + question1 = 'What time is it?' + question2 = 'How about now?' 
+ + query = wizard.chat(question1, debug=True, stream=False, convo=True, + print_references=True)[1] + assert question1 in query + assert question2 not in query + assert len(wizard.messages) == 3 + + query = wizard.chat(question2, debug=True, stream=False, convo=True, + print_references=True)[1] + assert question1 in query + assert question2 in query + assert len(wizard.messages) == 5 + + wizard.clear() + assert len(wizard.messages) == 1 + + query = wizard.chat(question2, debug=True, stream=False, convo=True, + print_references=True)[1] + assert question1 not in query + assert question2 in query + assert len(wizard.messages) == 3
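
---
Usage sketch (editor's note, not part of the patches above): the snippet below shows how the two features added in this series fit together -- an `ignore_error` callable consumed by `ApiQueue` so that un-retryable API failures are skipped instead of retried, and the new conversational `convo` flag plus `clear()` on `EnergyWizard.chat()`. The PDF path, corpus layout, and error substrings are illustrative assumptions, not values taken from the ELM codebase.

    # Illustrative sketch only -- paths, corpus contents, and the error
    # substrings below are assumptions for demonstration purposes.
    import pandas as pd

    from elm.pdf import PDFtoTXT
    from elm.wizard import EnergyWizard


    def ignore_error(error_msg):
        """Return True for API errors that should be dropped, not retried.

        ``ApiQueue`` calls this with the raw error string; returning True
        records an empty output for that request instead of retrying it
        (so it never counts toward ``max_retries``).
        """
        # hypothetical substrings -- match whatever errors you want to skip
        fatal = ('content management policy', 'maximum context length')
        return any(substr in error_msg for substr in fatal)


    async def clean_pdf(fp='report.pdf'):
        """Clean raw PDF text in parallel GPT calls, skipping fatal errors.

        Per the docstrings above, call this with ``await`` from ipython or
        jupyter rather than from a plain python script.
        """
        pdf = PDFtoTXT(fp)
        return await pdf.clean_txt_async(ignore_error=ignore_error,
                                         rate_limit=40e3)


    def chat_with_history(corpus):
        """Ask a follow-up question using the new conversational flag.

        ``corpus`` is assumed to be a list of dicts with "text",
        "embedding", and "ref" keys, as built in tests/test_wizard.py.
        """
        wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000,
                              ref_col='ref')
        wizard.chat('What generation sources are discussed?',
                    debug=False, stream=False, convo=True)
        # convo=True runs the semantic search on the full chat history, so
        # the follow-up can refer back to the previous question and answer
        answer = wizard.chat('Which of those is the cheapest?',
                             debug=False, stream=False, convo=True)
        wizard.clear()  # reset history before starting an unrelated thread
        return answer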