@@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
 }
 }

-// Returns the KV cache that will contain the context for the
-// ongoing prediction with the model.
-const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.addr;
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
 }

-// Returns the size of the KV cache
-size_t llama_get_kv_cache_size(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.size;
+#define LLAMA_MAX_RNG_STATE 64*1024
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size        = sizeof(size_t);
+    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size     = sizeof(size_t);
+    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size  = sizeof(size_t);
+    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size         = sizeof(size_t);
+    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv              = ctx->model.kv_self.buf.size;
+
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+
+    return s_total;
 }
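
For context, a caller normally uses the returned size to allocate the destination buffer before serializing. A minimal sketch (not part of the diff; the helper name is illustrative and a valid `ctx` is assumed):

#include <cstdint>
#include <vector>
#include "llama.h"

// Allocate a buffer large enough for the serialized context state.
// The RNG slot is padded to LLAMA_MAX_RNG_STATE, so llama_copy_state_data()
// fills this buffer exactly (see the asserts in the functions below).
static std::vector<uint8_t> make_state_buffer(struct llama_context * ctx) {
    return std::vector<uint8_t>(llama_get_state_size(ctx));
}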

-int llama_get_kv_cache_token_count(struct llama_context * ctx) {
-    return ctx->model.kv_self.n;
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    uint8_t * out = dest;
+
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << ctx->rng;
+
+        const size_t rng_size = rng_ss.str().size();
+        char rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
+        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+
+        memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size);
+        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+    }
+
+    // copy logits
+    {
+        const size_t logits_cap  = ctx->logits.capacity();
+        const size_t logits_size = ctx->logits.size();
+
+        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
+        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+
+        if (logits_size) {
+            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+        }
+
+        out += logits_cap * sizeof(float);
+    }
+
+    // copy embeddings
+    {
+        const size_t embedding_size = ctx->embedding.size();
+
+        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+
+        if (embedding_size) {
+            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
+            out += embedding_size * sizeof(float);
+        }
+    }
+
+    // copy kv cache
+    {
+        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
+        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+
+        if (kv_size) {
+            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+        }
+    }
+
+    const size_t written  = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(written == expected);
+
+    return written;
 }
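
A hedged usage sketch of the save path (not part of the commit; the helper name and raw-blob file format are assumptions for illustration):

#include <cstdint>
#include <cstdio>
#include <vector>
#include "llama.h"

// Serialize the context state and write it to a binary file as a raw blob.
static bool save_session(struct llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    const size_t n_written = llama_copy_state_data(ctx, buf.data());

    FILE * fp = std::fopen(path, "wb");
    if (!fp) {
        return false;
    }
    const bool ok = std::fwrite(buf.data(), 1, n_written, fp) == n_written;
    std::fclose(fp);
    return ok;
}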

-// Sets the KV cache containing the current context for the model
-void llama_set_kv_cache(
-        struct llama_context * ctx,
-        const uint8_t * kv_cache,
-        size_t n_size,
-        int n_token_count) {
-    // Make sure we have the same kv cache setup
-    LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
-    void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-    void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-    memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
-    ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-    ctx->model.kv_self.v->data = v_data;
-    ctx->model.kv_self.n = n_token_count;
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    const uint8_t * in = src;
+
+    // set rng
+    {
+        size_t rng_size;
+        char   rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size);
+        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+
+        std::stringstream rng_ss;
+        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        rng_ss >> ctx->rng;
+
+        LLAMA_ASSERT(rng_ss.fail() == false);
+    }
+
+    // set logits
+    {
+        size_t logits_cap;
+        size_t logits_size;
+
+        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
+        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+
+        LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
+
+        if (logits_size) {
+            ctx->logits.resize(logits_size);
+            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+        }
+
+        in += logits_cap * sizeof(float);
+    }
+
+    // set embeddings
+    {
+        size_t embedding_size;
+
+        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+
+        LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+
+        if (embedding_size) {
+            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+            in += embedding_size * sizeof(float);
+        }
+    }
+
+    // set kv cache
+    {
+        size_t kv_size;
+        int    kv_ntok;
+
+        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
+        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+
+        if (kv_size) {
+            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+
+            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+
+            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+
+            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+            ctx->model.kv_self.v->data = v_data;
+        }
+
+        ctx->model.kv_self.n = kv_ntok;
+    }
+
+    const size_t nread    = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(nread == expected);
+
+    return nread;
 }
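
And the matching restore path, again only an illustrative sketch pairing with the save sketch above (helper name and file layout are assumptions). The context must have been created with the same model and parameters, otherwise the capacity and size asserts in llama_set_state_data() will fire:

#include <cstdint>
#include <cstdio>
#include <vector>
#include "llama.h"

// Read a raw state blob written by save_session() and load it back into ctx.
static bool load_session(struct llama_context * ctx, const char * path) {
    FILE * fp = std::fopen(path, "rb");
    if (!fp) {
        return false;
    }
    std::fseek(fp, 0, SEEK_END);
    const long n_size = std::ftell(fp);
    std::fseek(fp, 0, SEEK_SET);
    if (n_size <= 0) {
        std::fclose(fp);
        return false;
    }

    std::vector<uint8_t> buf((size_t) n_size);
    const bool read_ok = std::fread(buf.data(), 1, buf.size(), fp) == buf.size();
    std::fclose(fp);

    return read_ok && llama_set_state_data(ctx, buf.data()) == buf.size();
}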

 int llama_eval(
@@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
     return ctx->model.tensors_by_name;
 }

-// Returns the size of the state
-size_t llama_get_state_size(struct llama_context * ctx) {
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size        = sizeof(size_t);
-    const size_t s_rng             = 64*1024;
-    const size_t s_logits_capacity = sizeof(size_t);
-    const size_t s_logits_size     = sizeof(size_t);
-    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
-    const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size         = sizeof(size_t);
-    const size_t s_kv_ntok         = sizeof(int);
-    const size_t s_kv              = llama_get_kv_cache_size(ctx);
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_logits_capacity
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_size
-        + s_kv_ntok
-        + s_kv
-    );
-    return s_total;
-}
-
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    std::stringstream rng_ss;
-    rng_ss << ctx->rng;
-    const size_t rng_size = rng_ss.str().size();
-    char rng_buf[64*1024];
-    memset(&rng_buf[0], 0, 64*1024);
-    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-    const size_t logits_capacity = ctx->logits.capacity();
-    const size_t logits_size = ctx->logits.size();
-    const size_t embedding_size = ctx->embedding.size();
-    const size_t kv_size = llama_get_kv_cache_size(ctx);
-    const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-
-    uint8_t * out = dest;
-    memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
-    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
-    if (logits_size) {
-        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
-    }
-    out += logits_capacity * sizeof(float);
-    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
-    if (embedding_size) {
-        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
-    }
-    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
-    if (kv_size) {
-        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
-    }
-    const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(written == expected);
-    return written;
-}
-
-// Sets the state reading from the specified source address
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    size_t rng_size;
-    char rng_buf[64*1024];
-    std::stringstream rng_ss;
-
-    const uint8_t * in = src;
-    memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
-    rng_ss.str(std::string(&rng_buf[0], rng_size));
-    rng_ss >> ctx->rng;
-    LLAMA_ASSERT(rng_ss.fail() == false);
-
-    size_t logits_capacity;
-    size_t logits_size;
-    size_t embedding_size;
-    size_t kv_size;
-    int kv_ntok;
-
-    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
-    if (logits_size) {
-        ctx->logits.resize(logits_size);
-        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
-    }
-    in += logits_capacity * sizeof(float);
-    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
-    if (embedding_size) {
-        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-        in += embedding_size * sizeof(float);
-    }
-    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
-    if (kv_size) {
-        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
-        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-        ctx->model.kv_self.v->data = v_data;
-        in += kv_size;
-    }
-    ctx->model.kv_self.n = kv_ntok;
-    const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(nread == expected);
-    return nread;
-}