@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -35,13 +37,14 @@
 #endif
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {
 
     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
     }
 
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@ class Opt {
 
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }
 
+        chat_template = initialize_chat_template(model, opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +590,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
@@ -607,16 +679,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
     }
 
     std::string basename(const std::string & path) {
@@ -628,6 +718,15 @@ class LlamaData {
         return path.substr(pos + 1);
     }
 
+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3);  // Include "://"
+    }
+
     int remove_proto(std::string & model_) {
         const std::string::size_type pos = model_.find("://");
         if (pos == std::string::npos) {
@@ -638,38 +737,40 @@ class LlamaData {
             return 0;
         }
 
-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }
 
+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;
 
         return ret;
     }
 
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +803,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        // if the chat template file doesn't exist, use the model's built-in template
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +835,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
@@ -730,8 +852,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                            std::vector<llama_token> & prompt_tokens) {
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }