
Commit c4fe84f

llama : refactor get / set state + remove redundant kv cache API (#1143)
1 parent 1d78fec commit c4fe84f

File tree

llama.cpp
llama.h

2 files changed: +179 -154 lines changed

llama.cpp

+179 -140
@@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
     }
 }
 
-// Returns the KV cache that will contain the context for the
-// ongoing prediction with the model.
-const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.addr;
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
 }
 
-// Returns the size of the KV cache
-size_t llama_get_kv_cache_size(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.size;
+#define LLAMA_MAX_RNG_STATE 64*1024
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size        = sizeof(size_t);
+    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size     = sizeof(size_t);
+    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size  = sizeof(size_t);
+    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size         = sizeof(size_t);
+    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv              = ctx->model.kv_self.buf.size;
+
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+
+    return s_total;
 }
 
-int llama_get_kv_cache_token_count(struct llama_context * ctx) {
-    return ctx->model.kv_self.n;
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    uint8_t * out = dest;
+
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << ctx->rng;
+
+        const size_t rng_size = rng_ss.str().size();
+        char rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
+        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+
+        memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size);
+        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+    }
+
+    // copy logits
+    {
+        const size_t logits_cap  = ctx->logits.capacity();
+        const size_t logits_size = ctx->logits.size();
+
+        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
+        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+
+        if (logits_size) {
+            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+        }
+
+        out += logits_cap * sizeof(float);
+    }
+
+    // copy embeddings
+    {
+        const size_t embedding_size = ctx->embedding.size();
+
+        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+
+        if (embedding_size) {
+            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
+            out += embedding_size * sizeof(float);
+        }
+    }
+
+    // copy kv cache
+    {
+        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
+        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+
+        if (kv_size) {
+            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+        }
+    }
+
+    const size_t written  = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(written == expected);
+
+    return written;
 }
 
-// Sets the KV cache containing the current context for the model
-void llama_set_kv_cache(
-        struct llama_context * ctx,
-               const uint8_t * kv_cache,
-                      size_t   n_size,
-                         int   n_token_count) {
-    // Make sure we have the same kv cache setup
-    LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
-    void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-    void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-    memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
-    ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-    ctx->model.kv_self.v->data = v_data;
-    ctx->model.kv_self.n = n_token_count;
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    const uint8_t * in = src;
+
+    // set rng
+    {
+        size_t rng_size;
+        char   rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size);
+        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+
+        std::stringstream rng_ss;
+        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        rng_ss >> ctx->rng;
+
+        LLAMA_ASSERT(rng_ss.fail() == false);
+    }
+
+    // set logits
+    {
+        size_t logits_cap;
+        size_t logits_size;
+
+        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
+        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+
+        LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
+
+        if (logits_size) {
+            ctx->logits.resize(logits_size);
+            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+        }
+
+        in += logits_cap * sizeof(float);
+    }
+
+    // set embeddings
+    {
+        size_t embedding_size;
+
+        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+
+        LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+
+        if (embedding_size) {
+            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+            in += embedding_size * sizeof(float);
+        }
+    }
+
+    // set kv cache
+    {
+        size_t kv_size;
+        int    kv_ntok;
+
+        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
+        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+
+        if (kv_size) {
+            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+
+            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+
+            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+
+            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+            ctx->model.kv_self.v->data = v_data;
+
+        }
+
+        ctx->model.kv_self.n = kv_ntok;
+    }
+
+    const size_t nread    = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(nread == expected);
+
+    return nread;
 }
 
 int llama_eval(
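For orientation, the byte layout of the blob written by llama_copy_state_data (and expected by llama_set_state_data) can be read off the memcpy sequence above. The annotated sketch below assumes sizeof(size_t) == 8; it is derived from this diff, not from a documented format:

    // state blob layout, in write order (sizes in bytes, assuming 8-byte size_t):
    //   rng_size        8          actual length of the serialized std::mt19937 text
    //   rng_buf         64*1024    rng text, zero-padded to LLAMA_MAX_RNG_STATE
    //   logits_cap      8          ctx->logits.capacity()
    //   logits_size     8          ctx->logits.size()
    //   logits          cap * 4    logits_size floats valid; cursor advances by full capacity
    //   embedding_size  8          ctx->embedding.size()
    //   embedding       size * 4   omitted entirely when embedding_size == 0
    //   kv_size         8          ctx->model.kv_self.buf.size
    //   kv_ntok         4          llama_get_kv_cache_token_count(ctx)
    //   kv cache        kv_size    raw kv_self buffer contents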
@@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
     return ctx->model.tensors_by_name;
 }
 
-// Returns the size of the state
-size_t llama_get_state_size(struct llama_context * ctx) {
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size = sizeof(size_t);
-    const size_t s_rng = 64*1024;
-    const size_t s_logits_capacity = sizeof(size_t);
-    const size_t s_logits_size = sizeof(size_t);
-    const size_t s_logits = ctx->logits.capacity() * sizeof(float);
-    const size_t s_embedding_size = sizeof(size_t);
-    const size_t s_embedding = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size = sizeof(size_t);
-    const size_t s_kv_ntok = sizeof(int);
-    const size_t s_kv = llama_get_kv_cache_size(ctx);
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_logits_capacity
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_size
-        + s_kv_ntok
-        + s_kv
-    );
-    return s_total;
-}
-
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    std::stringstream rng_ss;
-    rng_ss << ctx->rng;
-    const size_t rng_size = rng_ss.str().size();
-    char rng_buf[64*1024];
-    memset(&rng_buf[0], 0, 64*1024);
-    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-    const size_t logits_capacity = ctx->logits.capacity();
-    const size_t logits_size = ctx->logits.size();
-    const size_t embedding_size = ctx->embedding.size();
-    const size_t kv_size = llama_get_kv_cache_size(ctx);
-    const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-
-    uint8_t * out = dest;
-    memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
-    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
-    if (logits_size) {
-        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
-    }
-    out += logits_capacity * sizeof(float);
-    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
-    if (embedding_size) {
-        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
-    }
-    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
-    if (kv_size) {
-        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
-    }
-    const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(written == expected);
-    return written;
-}
-
-// Sets the state reading from the specified source address
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    size_t rng_size;
-    char rng_buf[64*1024];
-    std::stringstream rng_ss;
-
-    const uint8_t * in = src;
-    memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
-    rng_ss.str(std::string(&rng_buf[0], rng_size));
-    rng_ss >> ctx->rng;
-    LLAMA_ASSERT(rng_ss.fail() == false);
-
-    size_t logits_capacity;
-    size_t logits_size;
-    size_t embedding_size;
-    size_t kv_size;
-    int kv_ntok;
-
-    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
-    if (logits_size) {
-        ctx->logits.resize(logits_size);
-        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
-    }
-    in += logits_capacity * sizeof(float);
-    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
-    if (embedding_size) {
-        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-        in += embedding_size * sizeof(float);
-    }
-    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
-    if (kv_size) {
-        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
-        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-        ctx->model.kv_self.v->data = v_data;
-        in += kv_size;
-    }
-    ctx->model.kv_self.n = kv_ntok;
-    const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(nread == expected);
-    return nread;
-}
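A minimal caller-side sketch of the refactored API. The helper names snapshot/restore are hypothetical; the only library calls used are the three functions introduced above, with the signatures from this commit:

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // Take a full snapshot of a context: rng + logits + embedding + kv cache.
    static std::vector<uint8_t> snapshot(struct llama_context * ctx) {
        std::vector<uint8_t> buf(llama_get_state_size(ctx));
        const size_t written = llama_copy_state_data(ctx, buf.data());
        // llama_copy_state_data asserts written == llama_get_state_size(ctx),
        // so the buffer is exactly full here.
        (void) written;
        return buf;
    }

    // Restore a snapshot into the same (or an identically configured) context.
    static void restore(struct llama_context * ctx, const std::vector<uint8_t> & buf) {
        llama_set_state_data(ctx, buf.data());
    }

Compared to the removed llama_get_kv_cache / llama_set_kv_cache pair, the caller no longer touches kv_self internals: sizing, copying, and the k/v data-pointer fix-up all happen behind the two calls.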

llama.h

-14
@@ -112,23 +112,9 @@ extern "C" {
              const char * path_base_model,
                     int   n_threads);
 
-    // Returns the KV cache that will contain the context for the
-    // ongoing prediction with the model.
-    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
-
-    // Returns the size of the KV cache
-    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
-
     // Returns the number of tokens in the KV cache
     LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
 
-    // Sets the KV cache containing the current context for the model
-    LLAMA_API void llama_set_kv_cache(
-            struct llama_context * ctx,
-                   const uint8_t * kv_cache,
-                          size_t   n_size,
-                             int   n_token_count);
-
     // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
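For code that used the removed declarations, a hedged before/after sketch (the "before" calls no longer compile against this header; llama_get_kv_cache_token_count is the only KV-cache query that survives):

    // before (API removed in this commit): raw KV cache only
    //     const uint8_t * kv   = llama_get_kv_cache(ctx);
    //     const size_t    size = llama_get_kv_cache_size(ctx);
    //     const int       ntok = llama_get_kv_cache_token_count(ctx);
    //     ... stash kv/size/ntok somewhere ...
    //     llama_set_kv_cache(ctx, kv, size, ntok);

    // after: one opaque blob that also carries rng, logits, and embeddings
    std::vector<uint8_t> state(llama_get_state_size(ctx));
    llama_copy_state_data(ctx, state.data());
    // ... later, on the same or an identically configured context ...
    llama_set_state_data(ctx, state.data());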
