LM Studio inference server support (#167)
* updated the airoboros wrapper to catch the specific case where the final closing } is missing from the LLM's JSON output

* added LM Studio support
cpacker authored Oct 29, 2023
1 parent b988ef2 commit adb73fa
Showing 5 changed files with 68 additions and 3 deletions.
3 changes: 3 additions & 0 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -5,6 +5,7 @@
 import json
 
 from .webui.api import get_webui_completion
+from .lmstudio.api import get_lmstudio_completion
 from .llm_chat_completion_wrappers import airoboros, dolphin
 from .utils import DotDict
 
@@ -40,6 +41,8 @@ async def get_chat_completion(
     try:
         if HOST_TYPE == "webui":
            result = get_webui_completion(prompt)
+        elif HOST_TYPE == "lmstudio":
+            result = get_lmstudio_completion(prompt)
         else:
             print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
             result = get_webui_completion(prompt)
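
With this change, the backend is selected purely from environment variables that are read at import time. A minimal configuration sketch (the port is LM Studio's usual default, but both values here are assumptions rather than part of the diff):

```python
import os

# Set these before importing memgpt.local_llm modules, since HOST and
# HOST_TYPE are captured from the environment at import time.
os.environ["OPENAI_API_BASE"] = "http://localhost:1234"  # assumed LM Studio default port
os.environ["BACKEND_TYPE"] = "lmstudio"  # routes completions through get_lmstudio_completion
```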
5 changes: 4 additions & 1 deletion memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -391,7 +391,10 @@ def output_to_chat_completion_response(self, raw_llm_output):
         try:
             function_json_output = json.loads(raw_llm_output)
         except Exception as e:
-            raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
+            try:
+                function_json_output = json.loads(raw_llm_output + "\n}")
+            except:
+                raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
         function_name = function_json_output["function"]
         function_parameters = function_json_output["params"]
 
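
The retry exists because function-calling models sometimes emit JSON truncated just before the final closing brace. A standalone sketch of the same repair (the example string is invented for illustration):

```python
import json

raw = '{"function": "send_message", "params": {"message": "Hi"}'  # final } missing
try:
    parsed = json.loads(raw)
except json.JSONDecodeError:
    # Same repair as the wrapper above: append a closing brace and retry.
    parsed = json.loads(raw + "\n}")
assert parsed["function"] == "send_message"
```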
41 changes: 41 additions & 0 deletions memgpt/local_llm/lmstudio/api.py
@@ -0,0 +1,41 @@
+import os
+import requests
+
+# from .settings import SIMPLE
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
+LMSTUDIO_API_SUFFIX = "/v1/completions"
+DEBUG = False
+
+from .settings import SIMPLE
+
+
+def get_lmstudio_completion(prompt, settings=SIMPLE):
+    """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
+
+    # Settings for the generation, includes the prompt + stop tokens, max length, etc
+    request = settings
+    request["prompt"] = prompt
+
+    if not HOST.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+
+    try:
+        URI = os.path.join(HOST.strip("/"), LMSTUDIO_API_SUFFIX.strip("/"))
+        response = requests.post(URI, json=request)
+        if response.status_code == 200:
+            result = response.json()
+            # result = result["results"][0]["text"]
+            result = result["choices"][0]["text"]
+            if DEBUG:
+                print(f"json API response.text: {result}")
+        else:
+            raise Exception(
+                f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}."
+            )
+    except:
+        # TODO handle gracefully
+        raise
+
+    return result
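
A usage sketch, assuming the LM Studio local inference server is already running with a model loaded (the URL and the prompt format are assumptions):

```python
import os

# Must be set before the module is imported, because HOST is read at import time.
os.environ["OPENAI_API_BASE"] = "http://localhost:1234"

from memgpt.local_llm.lmstudio.api import get_lmstudio_completion

reply = get_lmstudio_completion("USER: What is 2 + 2?\nASSISTANT:")
print(reply)
```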
13 changes: 13 additions & 0 deletions memgpt/local_llm/lmstudio/settings.py
@@ -0,0 +1,13 @@
+SIMPLE = {
+    "stop": [
+        "\nUSER:",
+        "\nASSISTANT:",
+        "\nFUNCTION RETURN:",
+        # '\n' +
+        # '</s>',
+        # '<|',
+        # '\n#',
+        # '\n\n\n',
+    ],
+    "max_tokens": 500,
+}
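
get_lmstudio_completion merges these settings with the prompt, so the JSON body POSTed to /v1/completions looks roughly like this (a sketch; the prompt value is invented):

```python
payload = {
    "stop": ["\nUSER:", "\nASSISTANT:", "\nFUNCTION RETURN:"],  # cut generation at turn boundaries
    "max_tokens": 500,
    "prompt": "USER: What is 2 + 2?\nASSISTANT:",  # filled in by get_lmstudio_completion
}
```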
9 changes: 7 additions & 2 deletions memgpt/local_llm/webui/api.py
@@ -16,16 +16,21 @@ def get_webui_completion(prompt, settings=SIMPLE):
     request = settings
     request["prompt"] = prompt
 
+    if not HOST.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+
     try:
-        URI = f"{HOST.strip('/')}{WEBUI_API_SUFFIX}"
+        URI = os.path.join(HOST.strip("/"), WEBUI_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
             result = result["results"][0]["text"]
             if DEBUG:
                 print(f"json API response.text: {result}")
         else:
-            raise Exception(f"API call got non-200 response code for address: {URI}")
+            raise Exception(
+                f"API call got non-200 response code for address: {URI}. Make sure that the web UI server is running and reachable at {URI}."
+            )
     except:
         # TODO handle gracefully
         raise
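
The switch from string concatenation to os.path.join makes the URL tolerant of stray slashes on either side. A quick sketch (the suffix value is an assumption based on text-generation-webui's blocking API):

```python
import os

host = "http://localhost:5000/"  # user config with a trailing slash
suffix = "/api/v1/generate"      # assumed WEBUI_API_SUFFIX value
uri = os.path.join(host.strip("/"), suffix.strip("/"))
print(uri)  # http://localhost:5000/api/v1/generate (on POSIX; os.path.join uses "\\" on Windows)
```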
