diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py index a379914d19..1715ba445f 100644 --- a/memgpt/local_llm/lmstudio/api.py +++ b/memgpt/local_llm/lmstudio/api.py @@ -2,39 +2,59 @@ from urllib.parse import urljoin import requests -# from .settings import SIMPLE +from .settings import SIMPLE HOST = os.getenv("OPENAI_API_BASE") HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion -LMSTUDIO_API_SUFFIX = "/v1/completions" +LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions" +LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" DEBUG = False -from .settings import SIMPLE - -def get_lmstudio_completion(prompt, settings=SIMPLE): +def get_lmstudio_completion(prompt, settings=SIMPLE, api="chat"): """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client""" # Settings for the generation, includes the prompt + stop tokens, max length, etc request = settings - request["prompt"] = prompt + + if api == "chat": + # Uses the ChatCompletions API style + # Seems to work better, probably because it's applying some extra settings under-the-hood? + URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/")) + message_structure = [{"role": "user", "content": prompt}] + request["messages"] = message_structure + elif api == "completions": + # Uses basic string completions (string in, string out) + # Does not work as well as ChatCompletions for some reason + URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/")) + request["prompt"] = prompt + else: + raise ValueError(api) if not HOST.startswith(("http://", "https://")): raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://") try: - URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: result = response.json() - # result = result["results"][0]["text"] - result = result["choices"][0]["text"] + if api == "chat": + result = result["choices"][0]["message"]["content"] + elif api == "completions": + result = result["choices"][0]["text"] if DEBUG: print(f"json API response.text: {result}") else: - raise Exception( - f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}." - ) + # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} + if "context length" in str(response.text).lower(): + # "exceeds context length" is what appears in the LM Studio error message + # raise an alternate exception that matches OpenAI's message, which is "maximum context length" + raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})") + else: + raise Exception( + f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." + + f"Make sure that the LM Studio local inference server is running and reachable at {URI}." + ) except: # TODO handle gracefully raise diff --git a/memgpt/local_llm/lmstudio/settings.py b/memgpt/local_llm/lmstudio/settings.py index cdf8962e95..3d8622ace8 100644 --- a/memgpt/local_llm/lmstudio/settings.py +++ b/memgpt/local_llm/lmstudio/settings.py @@ -9,5 +9,12 @@ # '\n#', # '\n\n\n', ], - "max_tokens": 500, + # This controls the maximum number of tokens that the model can generate + # Cap this at the model context length (assuming 8k for Mistral 7B) + "max_tokens": 8000, + # This controls how LM studio handles context overflow + # In MemGPT we handle this ourselves, so this should be commented out + # "lmstudio": {"context_overflow_policy": 2}, + "stream": False, + "model": "local model", }