From 0f3e6b0bba196a5e0427ce542755ea3fe633f641 Mon Sep 17 00:00:00 2001 From: rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:55:12 +0000 Subject: [PATCH 1/3] fix oai proxy fix generation not stoped while bot stop talking in chat mode fix possible `slot_id` not exist response for cors (and pre flight) --- examples/server/api_like_OAI.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index 313e1a9652d14..a67cb9f275a61 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -72,7 +72,7 @@ def make_postData(body, chat=False, stream=False): if(is_present(body, "seed")): postData["seed"] = body["seed"] if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] if (args.stop != ""): - postData["stop"] = [args.stop] + postData["stop"] = [args.stop, args.ai_name.replace("\\n", "\n"), args.user_name.replace("\\n", "\n")] else: postData["stop"] = [] if(is_present(body, "stop")): postData["stop"] += body["stop"] @@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): } ] } - slot_id = data["slot_id"] + slot_id = data.get("slot_id") if (chat): if (start): resData["choices"][0]["delta"] = { @@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): return resData -@app.route('/chat/completions', methods=['POST']) -@app.route('/v1/chat/completions', methods=['POST']) +@app.route('/chat/completions', methods=['POST', 'OPTIONS']) +@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS']) def chat_completions(): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): return Response(status=403) + if request.method == 'OPTIONS': + return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) body = request.get_json() stream = False tokenize = False @@ -183,14 +185,16 @@ def generate(): decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) -@app.route('/completions', methods=['POST']) -@app.route('/v1/completions', methods=['POST']) +@app.route('/completions', methods=['POST', 'OPTIONS']) +@app.route('/v1/completions', methods=['POST', 'OPTIONS']) def completion(): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): return Response(status=403) + if request.method == 'OPTIONS': + return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) body = request.get_json() stream = False tokenize = False @@ -217,7 +221,7 @@ def generate(): decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) if __name__ == '__main__': app.run(args.host, port=args.port) From 73bb25e901ab765408cddd702529cfd699ccf3b3 Mon Sep 17 00:00:00 2001 From: rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:20:40 +0000 Subject: [PATCH 2/3] oai proxy: workaround for some client (such as Chatbox) --- examples/server/api_like_OAI.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index a67cb9f275a61..f4b661646506e 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -179,12 +179,12 @@ def generate(): data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) time_now = int(time.time()) resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) for line in data.iter_lines(): if line: decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) @@ -220,7 +220,7 @@ def generate(): if line: decoded_line = line.decode('utf-8') resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) + yield 'data: {}\n\n'.format(json.dumps(resData)) return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) if __name__ == '__main__': From ec8e74ebc80f5ce0168774ec65be8f31399f5ac2 Mon Sep 17 00:00:00 2001 From: rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> Date: Thu, 30 Nov 2023 20:25:01 +0000 Subject: [PATCH 3/3] use stop as separator to replace hardcoded `\n` --- examples/server/api_like_OAI.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index f4b661646506e..830c056d4acfc 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -11,10 +11,10 @@ slot_id = -1 parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") -parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') -parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") -parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") -parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") +parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.') +parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ") +parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ") +parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ") parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '')", default="") parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") @@ -34,19 +34,19 @@ def is_present(json, key): #convert chat to prompt def convert_chat(messages): - prompt = "" + args.chat_prompt.replace("\\n", "\n") - system_n = args.system_name.replace("\\n", "\n") - user_n = args.user_name.replace("\\n", "\n") - ai_n = args.ai_name.replace("\\n", "\n") - stop = args.stop.replace("\\n", "\n") + system_n = args.system_name + user_n = args.user_name + ai_n = args.ai_name + stop = args.stop + prompt = "" + args.chat_prompt + stop for line in messages: if (line["role"] == "system"): - prompt += f"{system_n}{line['content']}" + prompt += f"{system_n}{line['content']}{stop}" if (line["role"] == "user"): - prompt += f"{user_n}{line['content']}" + prompt += f"{user_n}{line['content']}{stop}" if (line["role"] == "assistant"): prompt += f"{ai_n}{line['content']}{stop}" prompt += ai_n.rstrip() @@ -72,7 +72,7 @@ def make_postData(body, chat=False, stream=False): if(is_present(body, "seed")): postData["seed"] = body["seed"] if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] if (args.stop != ""): - postData["stop"] = [args.stop, args.ai_name.replace("\\n", "\n"), args.user_name.replace("\\n", "\n")] + postData["stop"] = [args.stop] else: postData["stop"] = [] if(is_present(body, "stop")): postData["stop"] += body["stop"]