@@ -17,17 +17,26 @@ using json = nlohmann::ordered_json;
17
17
18
18
namespace minja {
19
19
20
+ struct chat_template_caps {
21
+ bool supports_tools = false ;
22
+ bool supports_tool_calls = false ;
23
+ bool supports_tool_responses = false ;
24
+ bool supports_system_role = false ;
25
+ bool supports_parallel_tool_calls = false ;
26
+ bool supports_tool_call_id = false ;
27
+ // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
28
+ // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
29
+ bool requires_object_arguments = false ;
30
+ // CohereForAI/c4ai-command-r-plus simple variant
31
+ bool requires_non_null_content = false ;
32
+ // MiniMaxAI/MiniMax-Text-01 special
33
+ bool requires_typed_content = false ;
34
+ };
35
+
20
36
class chat_template {
21
- public:
22
37
23
38
private:
24
- bool supports_tools_ = true ;
25
- // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
26
- // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
27
- bool requires_object_arguments_ = false ;
28
- bool requires_typed_content_ = false ;
29
- bool supports_system_role_ = true ;
30
- bool supports_parallel_tool_calls_ = false ;
39
+ chat_template_caps caps_;
31
40
std::string source_;
32
41
std::string bos_token_;
33
42
std::string eos_token_;
@@ -41,15 +50,16 @@ class chat_template {
41
50
{
42
51
try {
43
52
auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
44
- // fprintf(stderr, "Prompt : %s\n", prompt.c_str());
53
+ // fprintf(stderr, "try_raw_render : %s\n", prompt.c_str());
45
54
return prompt;
46
55
} catch (const std::exception & e) {
47
- // fprintf(stderr, "Error : %s\n", e.what());
56
+ // fprintf(stderr, "try_raw_render error : %s\n", e.what());
48
57
return " " ;
49
58
}
50
59
}
51
60
52
61
public:
62
+
53
63
chat_template (const std::string & source, const std::string & bos_token, const std::string & eos_token)
54
64
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
55
65
{
@@ -58,69 +68,120 @@ class chat_template {
58
68
/* .lstrip_blocks = */ true ,
59
69
/* .keep_trailing_newline = */ false ,
60
70
});
61
- supports_tools_ = source.find (" tools" ) != std::string::npos;
62
71
63
- auto renders_string_arguments =
64
- try_raw_render ({
65
- {
66
- {" role" , " user" },
67
- {" content" , " Hey" }
68
- },
69
- {
70
- {" role" , " assistant" },
71
- {" tool_calls" , json::array ({
72
- {
73
- {" id" , " call_1___" },
74
- {" type" , " function" },
75
- {" function" , {
76
- {" arguments" , " {\" code\" : \" print('Hello, World!')\" }" },
77
- {" name" , " ipython" },
72
+ auto contains = [](const std::string & haystack, const std::string & needle) {
73
+ return haystack.find (needle) != std::string::npos;
74
+ };
75
+
76
+ const std::string user_needle = " <User Needle>" ;
77
+ const std::string sys_needle = " <System Needle>" ;
78
+ const json dummy_str_user_msg = {{" role" , " user" }, {" content" , user_needle}};
79
+ const json dummy_typed_user_msg = {{" role" , " user" }, {" content" , json::array ({{{" type" , " text" }, {" text" , user_needle}}})}};
80
+
81
+ caps_.requires_typed_content =
82
+ !contains (try_raw_render (json::array ({dummy_str_user_msg}), {}, false ), user_needle)
83
+ && contains (try_raw_render (json::array ({dummy_typed_user_msg}), {}, false ), user_needle);
84
+
85
+ const auto dummy_user_msg = caps_.requires_typed_content
86
+ ? dummy_typed_user_msg
87
+ : dummy_str_user_msg;
88
+ const json needle_system_msg = {
89
+ {" role" , " system" },
90
+ {" content" , caps_.requires_typed_content ? json::array ({{{" type" , " text" }, {" text" , sys_needle}}}) : json (sys_needle)},
91
+ };
92
+
93
+ caps_.supports_system_role = contains (try_raw_render ({needle_system_msg, dummy_user_msg,}, {}, false ), sys_needle);
94
+
95
+ auto out = try_raw_render (json::array ({
96
+ dummy_user_msg
97
+ }), json::array ({
98
+ {
99
+ {" name" , " some_tool" },
100
+ {" type" , " function" },
101
+ {" function" , {
102
+ {" name" , " some_tool" },
103
+ {" description" , " Some tool." },
104
+ {" parameters" , {
105
+ {" type" , " object" },
106
+ {" properties" , {
107
+ {" arg" , {
108
+ {" type" , " string" },
109
+ {" description" , " Some argument." },
78
110
}},
79
- },
80
- })},
81
- }
82
- }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
83
- if (!renders_string_arguments) {
84
- auto renders_object_arguments =
85
- try_raw_render ({
86
- {
87
- {" role" , " user" },
88
- {" content" , " Hey" }
89
- },
90
- {
91
- {" role" , " assistant" },
92
- {" tool_calls" , json::array ({
93
- {
94
- {" id" , " call_1___" },
95
- {" type" , " function" },
96
- {" function" , {
97
- {" arguments" , {
98
- {" code" , " print('Hello, World!')" },
99
- }},
100
- {" name" , " ipython" },
101
- }},
102
- },
103
- })},
104
- }
105
- }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
106
- requires_object_arguments_ = renders_object_arguments;
107
- }
108
- supports_parallel_tool_calls_ = source.find (" tool_call_id" ) != std::string::npos;
111
+ }},
112
+ {" required" , json::array ({ " arg" })},
113
+ }},
114
+ }},
115
+ },
116
+ }), false );
117
+ caps_.supports_tools = contains (out, " some_tool" );
109
118
110
- supports_system_role_ = try_raw_render ({
111
- {{" role" , " system" }, {" content" , " <System Needle>" }},
112
- {{" role" , " user" }, {" content" , " Hey" }}
113
- }, {}, false ).find (" <System Needle>" ) != std::string::npos;
119
+ auto make_tool_calls_msg = [&](const json & tool_calls) {
120
+ return json {
121
+ {" role" , " assistant" },
122
+ {" content" , nullptr },
123
+ {" tool_calls" , tool_calls},
124
+ };
125
+ };
126
+ auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
127
+ return json {
128
+ {" id" , " call_1___" },
129
+ {" type" , " function" },
130
+ {" function" , {
131
+ {" arguments" , arguments},
132
+ {" name" , tool_name},
133
+ }},
134
+ };
135
+ };
136
+ const json dummy_args_obj {{" argument_needle" , " print('Hello, World!')" }};
137
+
138
+ // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
139
+ out = try_raw_render (json::array ({
140
+ dummy_user_msg,
141
+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj.dump ())})),
142
+ }), {}, false );
143
+ auto tool_call_renders_str_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
144
+ out = try_raw_render (json::array ({
145
+ dummy_user_msg,
146
+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj)})),
147
+ }), {}, false );
148
+ auto tool_call_renders_obj_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
149
+
150
+ caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
151
+ caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
152
+ auto out_empty = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , " " }}}), {}, false );
153
+ auto out_null = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , nullptr }}}), {}, false );
154
+ caps_.requires_non_null_content = contains (out_empty, user_needle) && !contains (out_null, user_needle);
155
+
156
+ if (caps_.supports_tool_calls ) {
157
+ auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json (dummy_args_obj.dump ());
158
+ auto tc1 = make_tool_call (" test_tool1" , dummy_args);
159
+ auto tc2 = make_tool_call (" test_tool2" , dummy_args);
160
+ auto out = try_raw_render (json::array ({
161
+ dummy_user_msg,
162
+ make_tool_calls_msg (json::array ({tc1, tc2})),
163
+ }), {}, false );
164
+ caps_.supports_parallel_tool_calls = contains (out, " test_tool1" ) && contains (out, " test_tool2" );
114
165
115
- requires_typed_content_ = try_raw_render ({{{" role" , " user" }, {" content" , " Hey" }}}, {}, false ).find (" Hey" ) == std::string::npos
116
- && try_raw_render ({{{" role" , " user" }, {" content" , {{{" type" , " text" }, {" text" , " Hey" }}}}}}, {}, false ).find (" Hey" ) != std::string::npos;
166
+ out = try_raw_render (json::array ({
167
+ dummy_user_msg,
168
+ make_tool_calls_msg (json::array ({tc1})),
169
+ {
170
+ {" role" , " tool" },
171
+ {" name" , " test_tool1" },
172
+ {" content" , " Some response!" },
173
+ {" tool_call_id" , " call_911_" },
174
+ }
175
+ }), {}, false );
176
+ caps_.supports_tool_responses = contains (out, " Some response!" );
177
+ caps_.supports_tool_call_id = contains (out, " call_911_" );
178
+ }
117
179
}
118
180
119
181
const std::string & source () const { return source_; }
120
182
const std::string & bos_token () const { return bos_token_; }
121
183
const std::string & eos_token () const { return eos_token_; }
122
- bool supports_tools () const { return supports_tools_; }
123
- bool supports_parallel_tool_calls () const { return supports_parallel_tool_calls_; }
184
+ const chat_template_caps & original_caps () const { return caps_; }
124
185
125
186
std::string apply (
126
187
const nlohmann::ordered_json & messages,
@@ -131,13 +192,19 @@ class chat_template {
131
192
{
132
193
json actual_messages;
133
194
134
- // First, "fix" messages so they have a chance to be rendered correctly by the template
135
-
136
- if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
195
+ auto needs_adjustments = adjust_inputs && (false
196
+ || !caps_.supports_system_role
197
+ || !caps_.supports_tools
198
+ || !caps_.supports_tool_responses
199
+ || !caps_.supports_tool_calls
200
+ || caps_.requires_object_arguments
201
+ || caps_.requires_typed_content
202
+ );
203
+ if (needs_adjustments) {
137
204
actual_messages = json::array ();
138
205
139
206
auto add_message = [&](const json & msg) {
140
- if (requires_typed_content_ && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
207
+ if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
141
208
actual_messages.push_back ({
142
209
{" role" , msg.at (" role" )},
143
210
{" content" , {{
@@ -160,24 +227,32 @@ class chat_template {
160
227
pending_system.clear ();
161
228
}
162
229
};
163
- for (const auto & message_ : messages) {
230
+ auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_.supports_tools ;
231
+
232
+ for (const auto & message_ : needs_tools_in_system ? add_system (messages, " Available tools: " + tools.dump (2 )) : messages) {
164
233
auto message = message_;
165
234
if (!message.contains (" role" ) || !message.contains (" content" )) {
166
235
throw std::runtime_error (" message must have 'role' and 'content' fields: " + message.dump ());
167
236
}
168
237
std::string role = message.at (" role" );
169
238
170
239
if (message.contains (" tool_calls" )) {
171
- if (requires_object_arguments_ || !supports_tools_ ) {
240
+ if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
172
241
for (auto & tool_call : message.at (" tool_calls" )) {
173
242
if (tool_call[" type" ] == " function" ) {
174
243
auto & function = tool_call.at (" function" );
175
- std::string arguments = function.at (" arguments" );
176
- function[" arguments" ] = json::parse (arguments);
244
+ auto & arguments = function.at (" arguments" );
245
+ if (arguments.is_string ()) {
246
+ try {
247
+ arguments = json::parse (arguments.get <std::string>());
248
+ } catch (const std::exception & ecvt ) {
249
+ fprintf (stderr, " Failed to parse arguments: %s\n " , ecvt .what ());
250
+ }
251
+ }
177
252
}
178
253
}
179
254
}
180
- if (!supports_tools_ ) {
255
+ if (!caps_. supports_tool_calls ) {
181
256
auto content = message.at (" content" );
182
257
auto tool_calls = json::array ();
183
258
for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -204,7 +279,7 @@ class chat_template {
204
279
message.erase (" tool_calls" );
205
280
}
206
281
}
207
- if (!supports_tools_ && role == " tool" ) {
282
+ if (!caps_. supports_tool_responses && role == " tool" ) {
208
283
message[" role" ] = " user" ;
209
284
auto obj = json {
210
285
{" tool_response" , {
@@ -219,7 +294,7 @@ class chat_template {
219
294
message.erase (" name" );
220
295
}
221
296
222
- if (!message[" content" ].is_null () && !supports_system_role_ ) {
297
+ if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
223
298
std::string content = message.at (" content" );
224
299
if (role == " system" ) {
225
300
if (!pending_system.empty ()) pending_system += " \n " ;
@@ -238,7 +313,9 @@ class chat_template {
238
313
}
239
314
add_message (message);
240
315
}
241
- flush_sys ();
316
+ if (!caps_.supports_system_role ) {
317
+ flush_sys ();
318
+ }
242
319
} else {
243
320
actual_messages = messages;
244
321
}
@@ -261,7 +338,28 @@ class chat_template {
261
338
}
262
339
}
263
340
264
- return template_root_->render (context);
341
+ auto ret = template_root_->render (context);
342
+ // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
343
+ // fprintf(stderr, "apply: %s\n\n", ret.c_str());
344
+ return ret;
345
+ }
346
+
347
+ static nlohmann::ordered_json add_system (const nlohmann::ordered_json & messages, const std::string & system_prompt) {
348
+ json messages_with_system = messages;
349
+
350
+ if (messages_with_system.size () > 0 && messages_with_system[0 ].at (" role" ) == " system" ) {
351
+ std::string existing_system = messages_with_system.at (0 ).at (" content" );
352
+ messages_with_system[0 ] = json {
353
+ {" role" , " system" },
354
+ {" content" , existing_system + " \n " + system_prompt},
355
+ };
356
+ } else {
357
+ messages_with_system.insert (messages_with_system.begin (), json {
358
+ {" role" , " system" },
359
+ {" content" , system_prompt},
360
+ });
361
+ }
362
+ return messages_with_system;
265
363
}
266
364
};
267
365
0 commit comments