
Commit 8dbb9e5

Added chat template support to llama-run
Fixes: #11178

The llama-run CLI currently doesn't take a model's chat template into account, so running llama-run on a model that requires one fails. To solve this, the chat template is now downloaded from Ollama or Hugging Face as well and applied during the chat.

Signed-off-by: Michael Engel <mengel@redhat.com>
1 parent ba8a1f9 commit 8dbb9e5
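
As context for the change below, the fix hinges on passing a template string, rather than nullptr, to llama_chat_apply_template. A minimal sketch of that call pattern (not part of the commit): render_prompt is a hypothetical helper name, model and messages are assumed to be already set up, and the signature shown is the one the diff itself uses (first parameter is the model; a nullptr template falls back to the model's built-in one).

// Sketch only: apply a custom chat template with the same grow-and-retry
// pattern as apply_chat_template() in the diff below.
#include <string>
#include <vector>
#include "llama.h"

static std::string render_prompt(llama_model * model, const std::vector<llama_chat_message> & messages,
                                 const std::string & chat_template) {
    // nullptr means "use the template embedded in the model's metadata"
    const char * tmpl = chat_template.empty() ? nullptr : chat_template.c_str();
    std::vector<char> buf(1024);
    int n = llama_chat_apply_template(model, tmpl, messages.data(), messages.size(),
                                      /*add_ass=*/true, buf.data(), buf.size());
    if (n > static_cast<int>(buf.size())) {
        // first pass only reported the required size; retry with a big enough buffer
        buf.resize(n);
        n = llama_chat_apply_template(model, tmpl, messages.data(), messages.size(),
                                      /*add_ass=*/true, buf.data(), buf.size());
    }
    return n > 0 ? std::string(buf.data(), n) : std::string();
}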

File tree

1 file changed (+146, -31 lines)


examples/run/run.cpp

+146, -31
@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -34,14 +36,17 @@
 }
 #endif
 
+#define LLAMA_USE_CURL
+
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +58,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +107,8 @@ class Opt {
 
     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
    std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +144,7 @@ class Opt {
     }
 
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +173,11 @@ class Opt {
 
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
@@ -475,7 +487,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +529,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
               progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +553,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }
 
+        chat_template = initialize_chat_template(opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +592,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
@@ -607,16 +681,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
+            }
+        }
+
+        return 0;
     }
 
     std::string basename(const std::string & path) {
@@ -638,38 +730,38 @@ class LlamaData {
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }
 
+        remove_proto(model_);
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;
 
         return ret;
     }
 
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +794,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const Opt & opt) {
+        // if the template file doesn't exist, just return an empty string
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return "";
+        }
+
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,13 +826,15 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+        llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+        llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr,
+        append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
+        result = llama_chat_apply_template(
+            llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+            llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
+            llama_data.fmtted.size());
     }
 
     return result;
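
A side note on the Hugging Face path: huggingface_dl_tmpl() above fetches tokenizer_config.json and reads its "chat_template" field via config["chat_template"], which throws if the key is absent or not a string. A minimal sketch of that extraction step with a guard, assuming nlohmann::json as in the diff; extract_chat_template is a hypothetical helper name, not part of the commit.

// Sketch only: pull the chat template out of a downloaded tokenizer_config.json,
// treating a missing or malformed template as "no template" (the config is optional).
#include <string>
#include <nlohmann/json.hpp>

static std::string extract_chat_template(const std::string & tokenizer_config_str) {
    // parse without exceptions; malformed JSON yields a discarded value
    nlohmann::json config = nlohmann::json::parse(tokenizer_config_str, nullptr, /*allow_exceptions=*/false);
    if (config.is_discarded() || !config.contains("chat_template") || !config["chat_template"].is_string()) {
        return "";
    }
    return config["chat_template"].get<std::string>();
}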
