@@ -1,12 +1,9 @@
-#include <vector>
-#include <cstdio>
-#include <chrono>
-
 #include "common.h"
 #include "llama.h"
-#include "llama.cpp"
 
-using namespace std;
+#include <vector>
+#include <cstdio>
+#include <chrono>
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -20,21 +17,25 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_predict < 0) {
+        params.n_predict = 16;
+    }
+
     auto lparams = llama_context_default_params();
 
-    lparams.n_ctx = params.n_ctx;
-    lparams.n_parts = params.n_parts;
-    lparams.seed = params.seed;
-    lparams.f16_kv = params.memory_f16;
-    lparams.use_mmap = params.use_mmap;
-    lparams.use_mlock = params.use_mlock;
+    lparams.n_ctx     = params.n_ctx;
+    lparams.n_parts   = params.n_parts;
+    lparams.seed      = params.seed;
+    lparams.f16_kv    = params.memory_f16;
+    lparams.use_mmap  = params.use_mmap;
+    lparams.use_mlock = params.use_mlock;
 
     auto n_past = 0;
-    auto last_n_tokens_data = vector<llama_token>(params.repeat_last_n, 0);
+    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
     // init
     auto ctx = llama_init_from_file(params.model.c_str(), lparams);
-    auto tokens = vector<llama_token>(params.n_ctx);
+    auto tokens = std::vector<llama_token>(params.n_ctx);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
 
     if (n_prompt_tokens < 1) {
@@ -43,26 +44,29 @@ int main(int argc, char ** argv) {
     }
 
     // evaluate prompt
-
     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
 
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
 
+    const size_t state_size = llama_get_state_size(ctx);
+    uint8_t * state_mem = new uint8_t[state_size];
+
     // Save state (rng, logits, embedding and kv_cache) to file
-    FILE *fp_write = fopen("dump_state.bin", "wb");
-    auto state_size = llama_get_state_size(ctx);
-    auto state_mem = new uint8_t[state_size];
-    llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-    fwrite(state_mem, 1, state_size, fp_write);
-    fclose(fp_write);
+    {
+        FILE *fp_write = fopen("dump_state.bin", "wb");
+        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
+        fwrite(state_mem, 1, state_size, fp_write);
+        fclose(fp_write);
+    }
 
     // save state (last tokens)
-    auto last_n_tokens_data_saved = vector<llama_token>(last_n_tokens_data);
-    auto n_past_saved = n_past;
+    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
+    const auto n_past_saved = n_past;
 
     // first run
     printf("\n%s", params.prompt.c_str());
+
     for (auto i = 0; i < params.n_predict; i++) {
         auto logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(ctx);
@@ -75,31 +79,42 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             return 1;
         }
         n_past += 1;
     }
+
     printf("\n\n");
 
     // free old model
     llama_free(ctx);
 
     // load new model
-
     auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
 
     // Load state (rng, logits, embedding and kv_cache) from file
-    FILE *fp_read = fopen("dump_state.bin", "rb");
-    auto state_size2 = llama_get_state_size(ctx2);
-    if (state_size != state_size2) {
-        fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+    {
+        FILE *fp_read = fopen("dump_state.bin", "rb");
+        if (state_size != llama_get_state_size(ctx2)) {
+            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+            return 1;
+        }
+
+        const size_t ret = fread(state_mem, 1, state_size, fp_read);
+        if (ret != state_size) {
+            fprintf(stderr, "\n%s : failed to read state\n", __func__);
+            return 1;
+        }
+
+        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+        fclose(fp_read);
     }
-    fread(state_mem, 1, state_size, fp_read);
-    llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
-    fclose(fp_read);
+
+    delete[] state_mem;
 
     // restore state (last tokens)
     last_n_tokens_data = last_n_tokens_data_saved;
@@ -118,13 +133,16 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             return 1;
         }
         n_past += 1;
     }
+
     printf("\n\n");
+
     return 0;
 }
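For reference, the save/restore round trip this diff cleans up reduces to three calls from the llama.cpp C API, all visible above: llama_get_state_size to size a buffer, llama_copy_state_data to serialize the rng, logits, embedding, and kv_cache into it, and llama_set_state_data to restore them into another context. Below is a minimal sketch of the same pattern without the file round trip; it assumes the API of this vintage (the one the diff uses), and snapshot_and_restore is a hypothetical helper, not part of the library. Model loading and prompt evaluation are omitted.

// Minimal sketch (hypothetical helper): snapshot one context's state into
// memory and restore it into a second context created with the same
// parameters. Uses only the calls that appear in the diff above.
#include <cstdint>
#include <vector>

#include "llama.h"

bool snapshot_and_restore(llama_context * src, llama_context * dst) {
    // Both contexts must have identical parameters, otherwise the state
    // sizes will not match (the same check the example performs).
    const size_t state_size = llama_get_state_size(src);
    if (state_size != llama_get_state_size(dst)) {
        return false;
    }

    std::vector<uint8_t> state_mem(state_size);
    llama_copy_state_data(src, state_mem.data()); // serialize src state
    llama_set_state_data(dst, state_mem.data());  // restore into dst
    return true;
}

A std::vector would also make the example's explicit delete[] unnecessary; the diff presumably keeps the raw buffer to mirror the memory-mapped-file variant its comments mention.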