Skip to content

Commit c08309e

Browse files
Rudimentary support for OpenAI chat completions tool calls (ggml-org#981)
* Rudimentary support of openai chat completions tools calls -Most small models are not smart enough to do this, especially a combined tool call + role play response, but at least this allows experimentation along these lines with koboldcpp * try to also support specified function and tool choice set to none Allow tools start and end messages to be configured in adapter Try to force grammar to specific function call if specified (untested) * ensure tools get listed right after user content and before end of user message content * omit grammars approach try prompting instead -use more extensive json parsing and direct instructions to models to try to obtain the desired result -seems to work relatively well with Mistral-7B-Instruct-v.0.3.Q4_K_M.gguf and neuralhermes-2.5-mistral-7b.Q4_K_M.gguf -question of whether this is too opinionated of an approach, should the instructions be things that can be passed with the prompt template? * add back llamacpp recommended json grammar Go back to adding grammar but use "official" llamacpp grammar only not a custom one just for openai * Tidy up, remove unnecessary globals * clarity * fix missing local variable error This worked to fix the error I mentioned on my last comment --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
1 parent 5caf5f9 commit c08309e

File tree

1 file changed

+108
-19
lines changed

1 file changed

+108
-19
lines changed

Diff for: koboldcpp.py

+108-19
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,33 @@ def string_contains_sequence_substring(inputstr,sequences):
829829
using_gui_launcher = False
830830
using_outdated_flags = False
831831

832+
# Used to parse json for openai tool calls
def extract_json_from_string(input_string):
    """Best-effort extraction of a JSON value from model output.

    Tries, in order:
      1. Parsing the whole string as JSON.
      2. Wrapping the string in [] (handles bare comma-separated objects).
      3. Scanning the string for embedded JSON objects/arrays and returning
         the first one that parses (handles prose mixed with JSON).

    Returns the parsed JSON value (dict or list), or [] if nothing
    parseable is found. Never raises.
    """
    try:  # First check if model exported perfect json
        return json.loads(input_string)
    except Exception:
        pass
    try:  # Next check if all we need is to add brackets to make it perfect json
        return json.loads(f"[{input_string}]")
    except Exception:
        pass
    # Fallback: scan for an embedded JSON object or array in case part of the
    # output is valid json and part is not. raw_decode handles arbitrary
    # nesting, unlike the previous non-greedy regex r'(\{.*?\}|\[.*?\])',
    # which stopped at the first closing brace/bracket and therefore failed
    # on any nested object such as {"a": {"b": 1}}.
    decoder = json.JSONDecoder()
    for start, ch in enumerate(input_string):
        if ch in "{[":
            try:
                parsed_json, _end = decoder.raw_decode(input_string, start)
                return parsed_json
            except Exception:
                continue
    return []
858+
832859
def transform_genparams(genparams, api_format):
833860
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
834861
#alias all nonstandard alternative names for rep pen.
@@ -873,15 +900,21 @@ def transform_genparams(genparams, api_format):
873900
user_message_end = adapter_obj.get("user_end", "")
874901
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
875902
assistant_message_end = adapter_obj.get("assistant_end", "")
903+
tools_message_start = adapter_obj.get("tools_start", "")
904+
tools_message_end = adapter_obj.get("tools_end", "")
876905
images_added = []
877906

907+
message_index = 0
878908
for message in messages_array:
909+
message_index += 1
879910
if message['role'] == "system":
880911
messages_string += system_message_start
881912
elif message['role'] == "user":
882913
messages_string += user_message_start
883914
elif message['role'] == "assistant":
884915
messages_string += assistant_message_start
916+
elif message['role'] == "tool":
917+
messages_string += tools_message_start
885918

886919
# content can be a string or an array of objects
887920
curr_content = message['content']
@@ -894,13 +927,64 @@ def transform_genparams(genparams, api_format):
894927
elif item['type']=="image_url":
895928
if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
896929
images_added.append(item['image_url']['url'].split(",", 1)[1])
897-
930+
# If last message, add any tools calls after message content and before message end token if any
931+
if message['role'] == "user" and message_index == len(messages_array):
932+
# Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
933+
tools_array = genparams.get('tools', [])
934+
if tools_array and len(tools_array) > 0 and genparams.get('tool_choice',None) != None:
935+
response_array = [{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}]
936+
json_formatting_instruction = " Use this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps(response_array, indent=0)
937+
tools_string = json.dumps(tools_array, indent=0)
938+
messages_string += tools_string
939+
specified_function = None
940+
if isinstance(genparams.get('tool_choice'), dict):
941+
try:
942+
specified_function = genparams.get('tool_choice').get('function').get('name')
943+
json_formatting_instruction = f"The user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
944+
except Exception as e:
945+
# In case of any issues, just revert back to no specified function
946+
pass
947+
messages_string += json_formatting_instruction
948+
949+
# Set temperature low automatically if function calling
950+
genparams["temperature"] = 0.2
951+
genparams["using_openai_tools"] = True
952+
953+
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
954+
genparams["grammar"] = r"""
955+
root ::= arr
956+
value ::= object | array | string | number | ("true" | "false" | "null") ws
957+
arr ::=
958+
"[\n" ws (
959+
value
960+
(",\n" ws value)*
961+
)? "]"
962+
object ::=
963+
"{" ws (
964+
string ":" ws value
965+
("," ws string ":" ws value)*
966+
)? "}" ws
967+
array ::=
968+
"[" ws (
969+
value
970+
("," ws value)*
971+
)? "]" ws
972+
string ::=
973+
"\"" (
974+
[^"\\\x7F\x00-\x1F] |
975+
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
976+
)* "\"" ws
977+
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
978+
ws ::= | " " | "\n" [ \t]{0,20}
979+
"""
898980
if message['role'] == "system":
899981
messages_string += system_message_end
900982
elif message['role'] == "user":
901983
messages_string += user_message_end
902984
elif message['role'] == "assistant":
903985
messages_string += assistant_message_end
986+
elif message['role'] == "tool":
987+
messages_string += tools_message_end
904988

905989
messages_string += assistant_message_start
906990
genparams["prompt"] = messages_string
@@ -913,6 +997,7 @@ def transform_genparams(genparams, api_format):
913997
genparams["stop_sequence"].append(assistant_message_start.strip())
914998
genparams["trim_stop"] = True
915999

1000+
9161001
elif api_format==5:
9171002
firstimg = genparams.get('image', "")
9181003
genparams["images"] = [firstimg]
@@ -963,9 +1048,8 @@ async def generate_text(self, genparams, api_format, stream_flag):
9631048
is_quiet = args.quiet
9641049
currfinishreason = "null"
9651050

966-
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
967-
968-
#flag instance as non-idle for a while
1051+
def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
1052+
# flag instance as non-idle for a while
9691053
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
9701054
if not washordereq:
9711055
global last_non_horde_req_time
@@ -1013,9 +1097,9 @@ def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
10131097
render_special=genparams.get('render_special', False),
10141098
banned_tokens=genparams.get('banned_tokens', []),
10151099
bypass_eos_token=genparams.get('bypass_eos', False),
1016-
)
1100+
)
10171101

1018-
genout = {"text":"","status":-1,"stopreason":-1}
1102+
genout = {"text": "", "status": -1, "stopreason": -1}
10191103
if stream_flag:
10201104
loop = asyncio.get_event_loop()
10211105
executor = ThreadPoolExecutor()
@@ -1024,9 +1108,9 @@ def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
10241108
genout = run_blocking()
10251109

10261110
recvtxt = genout['text']
1027-
currfinishreason = ("length" if (genout['stopreason']!=1) else "stop")
1111+
currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")
10281112

1029-
#flag instance as non-idle for a while
1113+
# flag instance as non-idle for a while
10301114
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
10311115
if not washordereq:
10321116
global last_non_horde_req_time
@@ -1035,27 +1119,32 @@ def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
10351119
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
10361120
utfprint("\nOutput: " + recvtxt)
10371121

1038-
if api_format==1:
1039-
res = {"data": {"seqs":[recvtxt]}}
1040-
elif api_format==3:
1122+
if api_format == 1:
1123+
res = {"data": {"seqs": [recvtxt]}}
1124+
elif api_format == 3:
10411125
res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
1042-
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
1043-
"choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
1044-
elif api_format==4:
1126+
"usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
1127+
"choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
1128+
elif api_format == 4:
1129+
using_openai_tools = genparams.get('using_openai_tools', False)
1130+
tool_calls = []
1131+
if using_openai_tools:
1132+
tool_calls = extract_json_from_string(recvtxt)
1133+
if tool_calls and len(tool_calls)>0:
1134+
recvtxt = None
10451135
res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
1046-
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
1047-
"choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": currfinishreason}]}
1048-
elif api_format==5:
1136+
"usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
1137+
"choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
1138+
elif api_format == 5:
10491139
res = {"caption": end_trim_to_sentence(recvtxt)}
10501140
else:
1051-
res = {"results": [{"text": recvtxt, "finish_reason":currfinishreason}]}
1141+
res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
10521142

10531143
try:
10541144
return res
10551145
except Exception as e:
10561146
print(f"Generate: Error while generating: {e}")
10571147

1058-
10591148
async def send_oai_sse_event(self, data):
10601149
if data=="[DONE]":
10611150
self.wfile.write(f'data: {data}'.encode())

0 commit comments

Comments
 (0)