server : remove legacy system_prompt feature #9857

Merged: 3 commits merged on Oct 12, 2024

Changes from all commits

17 changes: 0 additions & 17 deletions common/arg.cpp
@@ -1788,23 +1788,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_threads_http = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
add_opt(common_arg(
{"-spf", "--system-prompt-file"}, "FNAME",
"set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications",
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::string system_prompt;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(system_prompt)
);
params.system_prompt = system_prompt;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--metrics"},
string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
1 change: 0 additions & 1 deletion common/common.h
@@ -282,7 +282,6 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
bool enable_chat_template = true;

std::vector<std::string> api_keys;
6 changes: 1 addition & 5 deletions examples/server/README.md
@@ -149,7 +149,6 @@ The project is under active development, and we are [looking for feedback and co
| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
| `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
| `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
| `-spf, --system-prompt-file FNAME` | set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications |
| `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
| `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
| `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
@@ -320,7 +319,6 @@ node index.js

- The prompt is a string or an array with the first element given as a string
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`
- The system prompt is empty

`temperature`: Adjust the randomness of the generated text. Default: `0.8`

@@ -536,14 +534,12 @@ This endpoint is public (no API key check). By default, it is read-only. To make

```json
{
"system_prompt": "",
"default_generation_settings": { ... },
"total_slots": 1,
"chat_template": ""
}
```

- `system_prompt` - the system prompt (initial prompt of all slots). Please note that this does not take into account the chat template. It will append the prompt at the beginning of formatted prompt.
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
- `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option)
- `chat_template` - the model's original Jinja2 prompt template
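
As a quick sanity check of the new response shape, here is a minimal sketch (assumptions: a server already running on `localhost:8080`, and the cpp-httplib and nlohmann::json libraries the server itself builds against); note that `system_prompt` no longer appears in the response:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>  // or the json.hpp vendored in the repo

#include "httplib.h"

int main() {
    // assumption: server started with the default --host/--port
    httplib::Client cli("localhost", 8080);

    auto res = cli.Get("/props");
    if (res && res->status == 200) {
        auto props = nlohmann::json::parse(res->body);
        // "system_prompt" is no longer part of the response
        std::cout << "total_slots:   " << props["total_slots"]   << "\n"
                  << "chat_template: " << props["chat_template"] << std::endl;
    }
    return 0;
}
```
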
@@ -554,7 +550,7 @@ To use this endpoint with POST method, you need to start server with `--props`

*Options:*

- `system_prompt`: Change the system prompt (initial prompt of all slots). Please note that this does not take into account the chat template. It will append the prompt at the beginning of formatted prompt.
- None yet

### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API
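
With the global system prompt removed, the equivalent behavior is a per-request `system` message, which is formatted through the model's chat template. A minimal sketch, assuming a server on `localhost:8080` with a chat-capable model loaded (message contents are placeholders):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

#include "httplib.h"

int main() {
    httplib::Client cli("localhost", 8080);

    // what used to be the shared system prompt is now sent with every request
    nlohmann::json body = {
        {"messages", {
            {{"role", "system"}, {"content", "You are a helpful assistant."}},
            {{"role", "user"},   {"content", "Hello!"}}
        }}
    };

    auto res = cli.Post("/v1/chat/completions", body.dump(), "application/json");
    if (res && res->status == 200) {
        auto reply = nlohmann::json::parse(res->body);
        std::cout << reply["choices"][0]["message"]["content"] << std::endl;
    }
    return 0;
}
```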

103 changes: 18 additions & 85 deletions examples/server/server.cpp
@@ -623,12 +623,6 @@ struct server_context {

int32_t n_ctx; // total context for all clients / slots

// system prompt
bool system_need_update = false;

std::string system_prompt;
std::vector<llama_token> system_tokens;

// slots / clients
std::vector<server_slot> slots;
json default_generation_settings_for_props;
@@ -665,7 +659,7 @@ struct server_context {
bool load_model(const common_params & params_) {
params = params_;

// dedicate one sequence to the system prompt
// reserve one extra sequence (seq_id == 0) for extra features
params.n_parallel += 1;

common_init_result llama_init = common_init_from_params(params);
@@ -1061,51 +1055,6 @@ struct server_context {
clean_kv_cache = false;
}

void system_prompt_update() {
SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());

kv_cache_clear();
system_tokens.clear();

if (!system_prompt.empty()) {
system_tokens = common_tokenize(ctx, system_prompt, true);

const int32_t n_batch = llama_n_batch(ctx);
const int32_t n_tokens_prompt = system_tokens.size();

for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

common_batch_clear(batch);

for (int32_t j = 0; j < n_tokens; ++j) {
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
}

if (llama_decode(ctx, batch) != 0) {
SRV_ERR("%s", "llama_decode() failed\n");
return;
}
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

system_need_update = false;
}

bool system_prompt_set(const std::string & sys_prompt) {
SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());

system_prompt = sys_prompt;
// update system_tokens and KV cache as soon as all slots are idle
system_need_update = true;
return true;
}

bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
@@ -1855,12 +1804,8 @@ struct server_context {
}

if (all_idle) {
if (system_need_update) {
system_prompt_update();
}

SRV_INF("%s", "all slots are idle\n");
if (system_prompt.empty() && clean_kv_cache) {
if (clean_kv_cache) {
kv_cache_clear();
}

@@ -1882,7 +1827,7 @@ struct server_context {
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
if (!params.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should have already stopped in process_token()
@@ -1893,13 +1838,13 @@ struct server_context {

// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
const int n_left = slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
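// illustrative arithmetic (numbers not from this diff): with n_keep = 5 and n_past = 100,
// n_discard defaults to (100 - 5) / 2 = 47, so tokens [5, 52) are evicted and the
// remaining tokens [52, 100) are shifted left by 47 positions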

SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);

if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -1929,18 +1874,16 @@ struct server_context {

const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
common_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);

slot.n_past += 1;

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(slot.sampled);
}

SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
}

// process in chunks of params.n_batch
@@ -1971,7 +1914,7 @@ struct server_context {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
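// note: BOS is now governed solely by the model's add_bos metadata; previously it was
// suppressed whenever a global system prompt was set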
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
@@ -2050,7 +1993,7 @@ struct server_context {
} else {
if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
@@ -2138,22 +2081,19 @@ struct server_context {
}

// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
int p0 = slot.n_past;

if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

p0 = (int) system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
}
p0 = 0;

// there is no common part left (except for the system prompt)
// there is no common part left
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?

common_sampler_reset(slot.smpl);
}

@@ -2179,7 +2119,7 @@ struct server_context {
}
}

common_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -2409,10 +2349,6 @@ int main(int argc, char ** argv) {
// struct that contains llama context and inference
server_context ctx_server;

if (!params.system_prompt.empty()) {
ctx_server.system_prompt_set(params.system_prompt);
}

if (params.model_alias == "unknown") {
params.model_alias = params.model;
}
@@ -2840,7 +2776,6 @@ int main(int argc, char ** argv) {

const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json data = {
{ "system_prompt", ctx_server.system_prompt },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
@@ -2856,10 +2791,8 @@ int main(int argc, char ** argv) {
}

json data = json::parse(req.body);
if (data.contains("system_prompt")) {
std::string system_prompt = data.at("system_prompt");
ctx_server.system_prompt_set(system_prompt);
}

// update any props here

res_ok(res, {{ "success", true }});
};