feat: added new 'hint' wrappers that inject hints into the pre-prefix #707

Merged · 3 commits · Dec 25, 2023
18 changes: 18 additions & 0 deletions memgpt/functions/function_sets/base.py
@@ -95,6 +95,12 @@ def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
     Returns:
         str: Query result string
     """
+    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
+        page = 0
+    try:
+        page = int(page)
+    except:
+        raise ValueError(f"'page' argument must be an integer")
     count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
     results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
     num_pages = math.ceil(total / count) - 1  # 0 index
@@ -119,6 +125,12 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
     Returns:
         str: Query result string
     """
+    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
+        page = 0
+    try:
+        page = int(page)
+    except:
+        raise ValueError(f"'page' argument must be an integer")
     count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
     results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
     num_pages = math.ceil(total / count) - 1  # 0 index
@@ -156,6 +168,12 @@ def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
     Returns:
         str: Query result string
     """
+    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
+        page = 0
+    try:
+        page = int(page)
+    except:
+        raise ValueError(f"'page' argument must be an integer")
     count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
     results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
     num_pages = math.ceil(total / count) - 1  # 0 index
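The same six-line guard is repeated in all three search functions. A minimal standalone sketch of what it does — `coerce_page` is a hypothetical name, not part of the PR, and the bare `except:` from the PR is tightened here to catch only conversion errors:

```python
from typing import Optional, Union

def coerce_page(page: Optional[Union[int, str]]) -> int:
    # LLMs sometimes pass page=None or the string "none"/"None" instead of an int
    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
        page = 0
    try:
        return int(page)
    except (TypeError, ValueError):
        raise ValueError("'page' argument must be an integer")

assert coerce_page(None) == 0
assert coerce_page(" None ") == 0
assert coerce_page("2") == 2
```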
3 changes: 2 additions & 1 deletion memgpt/local_llm/chat_completion_proxy.py
@@ -82,7 +82,8 @@ def get_chat_completion(
 
     # First step: turn the message sequence into a prompt that the model expects
     try:
-        if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
+        # if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
+        if hasattr(llm_wrapper, "supports_first_message"):
             prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, first_message=first_message)
         else:
             prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
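A behavioral sketch of the dispatch change above (`DemoWrapper` is hypothetical): the proxy now keys on the attribute existing at all rather than on it being truthy, so a wrapper can declare `supports_first_message = False` and still receive the `first_message` kwarg, deciding internally what to do with it (as the chatml wrapper below does).

```python
class DemoWrapper:
    supports_first_message = False  # attribute present; its value no longer gates the call

llm_wrapper = DemoWrapper()
# old check: hasattr(...) and llm_wrapper.supports_first_message -> False, kwarg omitted
# new check: hasattr(...)                                        -> True, kwarg passed
print(hasattr(llm_wrapper, "supports_first_message"))  # True
```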
31 changes: 29 additions & 2 deletions memgpt/local_llm/llm_chat_completion_wrappers/chatml.py
@@ -5,6 +5,26 @@
 from ...errors import LLMJSONParsingError
 
 
+PREFIX_HINT = """# Reminders:
+# Important information about yourself and the user is stored in (limited) core memory
+# You can modify core memory with core_memory_replace
+# You can add to core memory with core_memory_append
+# Less important information is stored in (unlimited) archival memory
+# You can add to archival memory with archival_memory_insert
+# You can search archival memory with archival_memory_search
+# You will always see the statistics of archival memory, so you know if there is content inside it
+# If you receive new important information about the user (or yourself), you immediately update your memory with core_memory_replace, core_memory_append, or archival_memory_insert"""
+
+FIRST_PREFIX_HINT = """# Reminders:
+# This is your first interaction with the user!
+# Initial information about them is provided in the core memory user block
+# Make sure to introduce yourself to them
+# Your inner thoughts should be private, interesting, and creative
+# Do NOT use inner thoughts to communicate with the user
+# Use send_message to communicate with the user"""
+# Don't forget to use send_message, otherwise the user won't see your message"""
+
+
 class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
     """ChatML-style prompt formatter, tested for use with https://huggingface.co/ehartford/dolphin-2.5-mixtral-8x7b#training"""
 
@@ -24,12 +44,15 @@ def __init__(
         allow_function_role=False,  # use function role for function replies?
         no_function_role_role="assistant",  # if no function role, which role to use?
         no_function_role_prefix="FUNCTION RETURN:\n",  # if no function role, what prefix to use?
+        # add a guiding hint
+        assistant_prefix_hint=False,
     ):
         self.simplify_json_content = simplify_json_content
         self.clean_func_args = clean_function_args
         self.include_assistant_prefix = include_assistant_prefix
         self.assistant_prefix_extra = assistant_prefix_extra
         self.assistant_prefix_extra_first_message = assistant_prefix_extra_first_message
+        self.assistant_prefix_hint = assistant_prefix_hint
 
         # role-based
         self.allow_custom_roles = allow_custom_roles
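The new flag defaults to False, so existing wrapper configurations are unchanged; enabling it is a one-argument opt-in, mirroring how the registry in memgpt/local_llm/utils.py constructs the new wrappers:

```python
from memgpt.local_llm.llm_chat_completion_wrappers import chatml

# Hint injection off (default) vs. on
plain_wrapper = chatml.ChatMLInnerMonologueWrapper()
hint_wrapper = chatml.ChatMLInnerMonologueWrapper(assistant_prefix_hint=True)
```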
@@ -202,7 +225,9 @@ def chat_completion_to_prompt(self, messages, functions, first_message=False):
 
         if self.include_assistant_prefix:
             prompt += f"\n<|im_start|>assistant"
-            if first_message:
+            if self.assistant_prefix_hint:
+                prompt += f"\n{FIRST_PREFIX_HINT if first_message else PREFIX_HINT}"
+            if self.supports_first_message and first_message:
                 if self.assistant_prefix_extra_first_message:
                     prompt += self.assistant_prefix_extra_first_message
                 else:
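With `assistant_prefix_hint=True`, the hint lands immediately after the assistant header, so the tail of the generated prompt would look roughly like this (a sketch assembled from the constants above, not captured from a real run):

```
<|im_start|>assistant
# Reminders:
# Important information about yourself and the user is stored in (limited) core memory
# You can modify core memory with core_memory_replace
...
```

On a first message the `FIRST_PREFIX_HINT` block is substituted instead.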
@@ -355,7 +380,9 @@ def output_to_chat_completion_response(self, raw_llm_output, first_message=False):
         """
 
         # If we used a prefix to guide generation, we need to add it to the output as a prefix
-        assistant_prefix = self.assistant_prefix_extra_first_message if first_message else self.assistant_prefix_extra
+        assistant_prefix = (
+            self.assistant_prefix_extra_first_message if (self.supports_first_message and first_message) else self.assistant_prefix_extra
+        )
         if assistant_prefix and raw_llm_output[: len(assistant_prefix)] != assistant_prefix:
             raw_llm_output = assistant_prefix + raw_llm_output

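A minimal illustration of the re-attachment logic above: when generation was guided by an assistant prefix, the prefix is prepended to the raw output if the model did not echo it back (the prefix value here is invented for the example):

```python
assistant_prefix = '{\n  "function":'  # example guiding prefix, not the wrapper's actual default
raw_llm_output = ' "send_message", "params": {...}}'

if assistant_prefix and raw_llm_output[: len(assistant_prefix)] != assistant_prefix:
    raw_llm_output = assistant_prefix + raw_llm_output
print(raw_llm_output)  # prefix + model continuation
```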
2 changes: 1 addition & 1 deletion memgpt/local_llm/settings/simple.py
@@ -17,7 +17,7 @@
         # '\n' +
         # '</s>',
         # '<|',
-        # '\n#',
+        "\n#",
         # "\n\n\n",
         # prevent chaining function calls / multi json objects / run-on generations
         # NOTE: this requires the ability to patch the extra '}}' back into the prompt
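Presumably this stop token was enabled because the injected hints are `#`-prefixed comment lines: stopping on `"\n#"` keeps the model from generating additional reminder-style lines of its own. (This is an inference; the PR does not state the rationale.)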
3 changes: 3 additions & 0 deletions memgpt/local_llm/utils.py
@@ -50,6 +50,9 @@ def get_available_wrappers() -> dict:
         # New chatml-based wrappers
         "chatml": chatml.ChatMLInnerMonologueWrapper(),
         "chatml-noforce": chatml.ChatMLOuterInnerMonologueWrapper(),
+        # With extra hints
+        "chatml-hints": chatml.ChatMLInnerMonologueWrapper(assistant_prefix_hint=True),
+        "chatml-noforce-hints": chatml.ChatMLOuterInnerMonologueWrapper(assistant_prefix_hint=True),
         # Legacy wrappers
         "airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
         "airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(assistant_prefix_extra=None),