@@ -97,7 +97,9 @@ struct llama_context {
     llama_model model;
     llama_vocab vocab;

-    size_t mem_per_token = 0;
+    // used to estimate memory requirements experimentally
+    size_t mem_at_token0 = 0; // first time
+    size_t mem_at_token1 = 0; // second time

     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -626,14 +628,25 @@ static bool llama_eval_internal(
     const int n_vocab = hparams.n_vocab;
     const int n_rot   = hparams.n_embd/hparams.n_head;

-    auto & mem_per_token = lctx.mem_per_token;
+    auto & mem_at_token0 = lctx.mem_at_token0;
+    auto & mem_at_token1 = lctx.mem_at_token1;

     // TODO: fix this hardcoded size
-    static size_t buf_size = 512u*1024*1024;
+    static size_t buf_size = size_t(n_ctx)*1024*1024;
     static void * buf = malloc(buf_size);

-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
+    const size_t  C0 = mem_at_token0;                 // ~base
+    const int64_t C1 = mem_at_token1 - mem_at_token0; // delta 0,1
+
+    // fprintf(stderr, "\n%s: C0:%zu C1:%zu\n", __func__, C0, C1);
+
+    // const size_t size_estimate = C0 + size_t(C1 * (n_past + N)); // TODO(Green-Sky): determine relation to N (batch size)
+    const size_t size_estimate = C0 + C1 * n_past;
+
+    // fprintf(stderr, "\n%s: size_estimate %zu bytes (%zu | %zu)\n", __func__, size_estimate, mem_at_token0, mem_at_token1);
+
+    if (mem_at_token0 > 0 && mem_at_token1 > 0 && size_estimate > buf_size) {
+        const size_t buf_size_new = 1.1*size_estimate; // just grow by 10%
         // fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

         // reallocate
@@ -830,8 +843,12 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }

-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
+    if (mem_at_token0 == 0) {
+        mem_at_token0 = ggml_used_mem(ctx0);
+    } else if (mem_at_token1 == 0) {
+        mem_at_token1 = ggml_used_mem(ctx0);
+    } else {
+        // fprintf(stderr, "estimate/used_mem = %f\n", double(size_estimate) / ggml_used_mem(ctx0));
     }

     // fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
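The patch amounts to a two-point linear fit: it records ggml's actual arena usage after the first and the second evaluation, treats the first sample as the base cost C0 and the difference as the per-token delta C1, and extrapolates C0 + C1 * n_past for later calls, growing the scratch buffer by 10% when the estimate exceeds it. Below is a minimal standalone sketch of that idea; `MemEstimator`, `record`, and `estimate` are illustrative names for this sketch only, not part of llama.cpp, and the byte figures are made up.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Two-point linear memory estimator, mirroring the patch's approach:
// sample actual usage at the first two evaluations, then extrapolate.
struct MemEstimator {
    size_t mem_at_token0 = 0; // usage after the first eval (~base cost C0)
    size_t mem_at_token1 = 0; // usage after the second eval

    // Record a measurement; only the first two samples calibrate the model.
    void record(size_t used_bytes) {
        if (mem_at_token0 == 0) {
            mem_at_token0 = used_bytes;
        } else if (mem_at_token1 == 0) {
            mem_at_token1 = used_bytes;
        }
    }

    bool calibrated() const { return mem_at_token0 > 0 && mem_at_token1 > 0; }

    // size_estimate = C0 + C1 * n_past, as in the patch.
    size_t estimate(int n_past) const {
        const size_t  C0 = mem_at_token0;                                   // ~base
        const int64_t C1 = int64_t(mem_at_token1) - int64_t(mem_at_token0); // per-token delta
        return C0 + size_t(C1 * n_past);
    }
};

int main() {
    MemEstimator est;
    est.record(100u * 1024 * 1024); // hypothetical usage after eval #1
    est.record(101u * 1024 * 1024); // hypothetical usage after eval #2

    // Before a later eval, check whether the scratch buffer must grow.
    size_t buf_size = 128u * 1024 * 1024;
    const int n_past = 512;
    if (est.calibrated() && est.estimate(n_past) > buf_size) {
        buf_size = size_t(1.1 * est.estimate(n_past)); // grow by 10%, as in the patch
    }
    printf("estimated need: %zu bytes, buffer: %zu bytes\n",
           est.estimate(n_past), buf_size);
    return 0;
}
```

The point of sampling twice rather than dividing a single sample by N (the old `mem_per_token` scheme) is that usage is affine in context length, not proportional to it: the delta between the two samples isolates the per-token growth, so the fixed graph overhead no longer gets multiplied along with it.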