diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c2fcce42a7d58..abb7e526f6171 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -285,11 +285,15 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only) {
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+        const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
+        // restore later
+        // TODO: something cleaner
+        const auto n_outputs_save = n_outputs;
+
         // max number of outputs
         n_outputs = n_tokens;
 
@@ -341,6 +345,8 @@ llama_context::llama_context(
             }
         }
 
+        n_outputs = n_outputs_save;
+
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft = backend_buft[i];
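
The diff temporarily overrides n_outputs with the worst-case token count while reserving the graph, then restores the saved value by hand, as flagged by the "TODO: something cleaner". Below is a minimal sketch (not part of the diff) of one cleaner alternative: an RAII guard that saves the value on entry and writes it back on scope exit. The names value_restorer and reserve_worst_case_graph are hypothetical and do not exist in llama.cpp; the sketch only assumes n_outputs is an int32_t-like member.

// Hypothetical sketch, not an existing llama.cpp helper.
#include <cstdint>

template <typename T>
struct value_restorer {
    explicit value_restorer(T & ref) : ref_(ref), saved_(ref) {}
    ~value_restorer() { ref_ = saved_; } // write the saved value back on scope exit

    value_restorer(const value_restorer &) = delete;
    value_restorer & operator=(const value_restorer &) = delete;

private:
    T & ref_;
    T   saved_; // copy of the original value taken in the constructor
};

// usage sketch, assuming n_outputs is an int32_t member as in llama_context
void reserve_worst_case_graph(int32_t & n_outputs, uint32_t n_tokens) {
    value_restorer<int32_t> restore(n_outputs); // remembers the current value
    n_outputs = (int32_t) n_tokens;             // temporary worst-case override
    // ... build and reserve the worst-case graphs here ...
}   // n_outputs is restored automatically when `restore` goes out of scope

With such a guard, the explicit "n_outputs = n_outputs_save;" line in the second hunk would no longer be needed, and the restore could not be skipped by an early return.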