
Commit e986f94

Authored by chrfalch and prusnak
Added api for getting/setting the kv_cache (#685)
The API provides access methods for retrieving the current memory buffer of the kv_cache and its token count. It also contains a method for setting the kv_cache from a memory buffer. This makes it possible to load/save history - maybe support a --cache-prompt parameter as well?

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
1 parent c0bb1d3 commit e986f94

File tree

2 files changed, +44 -0 lines changed:
  llama.cpp (+27)
  llama.h (+17)

Diff for: llama.cpp (+27)
@@ -1668,6 +1668,33 @@ int llama_model_quantize(
     return 0;
 }
 
+// Returns the KV cache that will contain the context for the
+// ongoing prediction with the model.
+const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.data();
+}
+
+// Returns the size of the KV cache
+size_t llama_get_kv_cache_size(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.size();
+}
+
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
+}
+
+// Sets the KV cache containing the current context for the model
+void llama_set_kv_cache(
+        struct llama_context * ctx,
+               const uint8_t * kv_cache,
+                        size_t n_size,
+                           int n_token_count) {
+    // Make sure we have the same kv cache setup
+    LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
+    memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
+    ctx->model.kv_self.n = n_token_count;
+}
+
 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
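A minimal sketch of how these new accessors could be used to snapshot the cache, assuming a struct llama_context * ctx that has already evaluated some tokens. The kv_snapshot struct and kv_snapshot_take helper are illustrative names and are not part of this commit; only llama_get_kv_cache, llama_get_kv_cache_size, and llama_get_kv_cache_token_count come from the diff above.

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include "llama.h"

// Hypothetical caller-side container for a saved KV cache.
struct kv_snapshot {
    uint8_t * data;
    size_t    size;
    int       n_tokens;
};

// Hypothetical helper (not part of this commit): copy the current KV cache
// into a caller-owned buffer so it can be written to disk or kept in memory.
static struct kv_snapshot kv_snapshot_take(struct llama_context * ctx) {
    struct kv_snapshot snap;
    snap.size     = llama_get_kv_cache_size(ctx);
    snap.n_tokens = llama_get_kv_cache_token_count(ctx);
    snap.data     = malloc(snap.size);
    memcpy(snap.data, llama_get_kv_cache(ctx), snap.size);
    return snap;
}
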

Diff for: llama.h (+17)
@@ -83,6 +83,23 @@ extern "C" {
             const char * fname_out,
                    int   itype);
 
+    // Returns the KV cache that will contain the context for the
+    // ongoing prediction with the model.
+    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
+
+    // Returns the size of the KV cache
+    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
+
+    // Returns the number of tokens in the KV cache
+    LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
+
+    // Sets the KV cache containing the current context for the model
+    LLAMA_API void llama_set_kv_cache(
+            struct llama_context * ctx,
+             const uint8_t * kv_cache,
+                      size_t n_size,
+                         int n_token_count);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
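A matching restore step, again as a hedged sketch using the kv_snapshot struct from the earlier example: the buffer must have been taken from a context with the same model and cache size, since llama_set_kv_cache asserts that the buffer sizes match. kv_snapshot_restore is an illustrative name, not part of this commit.

// Hypothetical helper (not part of this commit): push a saved snapshot back
// into a context. snap->size must equal llama_get_kv_cache_size(ctx),
// otherwise the LLAMA_ASSERT inside llama_set_kv_cache fires.
static void kv_snapshot_restore(struct llama_context * ctx,
                                const struct kv_snapshot * snap) {
    llama_set_kv_cache(ctx, snap->data, snap->size, snap->n_tokens);
}

With a pair of helpers like this, a long system prompt could be evaluated once, saved, and reloaded on a later run instead of being re-evaluated, which is the --cache-prompt idea floated in the commit message.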
