@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -34,14 +36,17 @@
 }
 #endif
 
+#define LLAMA_USE_CURL
+
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX);  // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +58,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +107,8 @@ class Opt {
 
     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +144,7 @@ class Opt {
     }
 
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +173,11 @@ class Opt {
 
             ++positional_args_i;
             model_ = argv[i];
+        } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+            if (i + 1 >= argc) {
+                return 1;
+            }
+            chat_template_ = argv[++i];
         } else if (positional_args_i == 1) {
             ++positional_args_i;
             user = argv[i];
@@ -475,7 +487,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +529,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +553,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }
 
+        chat_template = initialize_chat_template(opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +592,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
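
A side note on the tokenizer_config.json handling added in this hunk: std::string tmpl = config["chat_template"]; throws a nlohmann::json type error when the key is absent or not a string, and parse() throws on malformed JSON. A minimal, hedged sketch of a more defensive lookup, reusing the names from this hunk (the guard itself is a suggestion, not part of the change):

    // sketch: tolerate tokenizer_config.json files without a usable chat_template
    nlohmann::json config = nlohmann::json::parse(tokenizer_config_str, nullptr, /*allow_exceptions=*/false);
    if (config.is_discarded() || !config.contains("chat_template") || !config["chat_template"].is_string()) {
        return 0;  // the template is optional, so continue without one
    }
    const std::string tmpl = config["chat_template"].get<std::string>();
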
@@ -607,16 +681,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
+            }
+        }
+
+        return 0;
     }
 
     std::string basename(const std::string & path) {
@@ -638,38 +730,38 @@ class LlamaData {
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }
 
+        remove_proto(model_);
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
-        model_ = bn;
+        model_         = bn;
+        chat_template_ = tn;
 
         return ret;
     }
 
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +794,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const Opt & opt) {
+        // if the template file doesn't exist, just return an empty string
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return "";
+        }
+
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,13 +826,15 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+        llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+        llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr,
+        append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
+        result = llama_chat_apply_template(
+            llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+            llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
+            llama_data.fmtted.size());
     }
 
     return result;
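
For context on the two calls in apply_chat_template above: llama_chat_apply_template returns the length the formatted prompt needs, which can exceed the buffer passed in, so the idiom is one call to learn the size and a second call to fill the resized buffer. Note that this C API pattern-matches the template text against known formats rather than executing it as Jinja. A standalone sketch of the pattern, assuming a loaded llama_model * model and the includes already present in this file (the ChatML-style template string and message content are illustrative only, not taken from this change):

    // sketch: format one message with a caller-supplied chat template (illustrative values)
    std::vector<llama_chat_message> msgs = { { "user", "Hello there" } };
    // contains "<|im_start|>", so it is detected as the ChatML format
    const char * tmpl = "{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}{% endfor %}";
    std::vector<char> buf(256);
    int32_t n = llama_chat_apply_template(model, tmpl, msgs.data(), msgs.size(), /*add_ass=*/true, buf.data(), buf.size());
    if (n > (int32_t) buf.size()) {  // first call only reported the required length
        buf.resize(n);
        n = llama_chat_apply_template(model, tmpl, msgs.data(), msgs.size(), true, buf.data(), buf.size());
    }
    std::string prompt(buf.data(), n > 0 ? (size_t) n : 0);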