
Merge pull request #3 from anon998/sse
Add streaming via server-sent events.
It has some changes I didn't make, and I decided I prefer "stream" to "streaming".
digiwombat authored May 31, 2023
2 parents a25f830 + 2533878 commit e6de69a
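
As a rough illustration of the feature this merge adds (not part of the commit itself), a client enables streaming by setting "stream": true in the JSON body it POSTs to /completion. The field names follow the diff below; the prompt, port, and everything else in this sketch are assumptions:

    // Hypothetical client-side sketch: building the JSON body that the patched
    // parse_options_completion() expects. Prompt and values are made up.
    #include <iostream>
    #include "json.hpp"  // nlohmann/json single-header library, as used by the server example

    int main() {
        nlohmann::json body = {
            {"prompt", "Building a website can be done in 10 simple steps:"},  // assumed example
            {"n_predict", 64},
            {"stream", true}  // new option from this PR; false keeps the one-shot JSON reply
        };

        // POST this to the /completion endpoint (host/port assumed, e.g. http://localhost:8080).
        // With "stream": true the reply is a text/event-stream of events shaped like
        //   data: {"content":"<token text>","stop":false}\n\n
        // followed by a final event with "stop": true plus generation statistics.
        std::cout << body.dump() << std::endl;
        return 0;
    }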
Showing 1 changed file with 70 additions and 94 deletions.
examples/server/server.cpp: 164 changes (70 additions, 94 deletions)
@@ -13,7 +13,7 @@ struct server_params

 struct llama_server_context
 {
-    bool streaming = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";

@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;

     void rewind() {
-        streaming = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(streaming) {
-            generated_text = "";
-        }

         std::string token_text = llama_token_to_str(ctx, token);
         generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
             }
         }

-        return generated_text;
+        return token_text;
     }

     std::vector<float> embedding(std::string content, int threads) {
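
A side note on the doCompletion() change above: the method still accumulates the full output in generated_text, but now returns only the newly generated piece, which is what lets the streaming path send per-token deltas. A minimal sketch of that accumulate-and-return-delta pattern (the names here are illustrative, not the server's actual API):

    // Illustrative sketch of the return-the-delta pattern; not code from the commit.
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Generator {
        std::string generated_text;  // full text so far, used by the non-streaming path
        std::vector<std::string> pending{"Hello", ", ", "world", "!"};  // stand-in for sampled tokens
        std::size_t i = 0;

        bool has_next_token() const { return i < pending.size(); }

        // Returns only the new token's text, like the patched doCompletion().
        std::string doCompletion() {
            std::string token_text = pending[i++];
            generated_text += token_text;  // still accumulated for the final summary
            return token_text;
        }
    };

    int main() {
        Generator g;
        while (g.has_next_token()) {
            std::cout << "delta: " << g.doCompletion() << "\n";  // what the SSE path would send
        }
        std::cout << "full:  " << g.generated_text << "\n";      // what the non-stream path returns
        return 0;
    }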
@@ -478,13 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para

 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
     gpt_params default_params;
-    if (!body["streaming"].is_null())
+    if (!body["stream"].is_null())
     {
-        llama.streaming = body["streaming"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
     }
     else
     {
-        llama.streaming = false;
+        llama.stream = false;
     }
     if (!body["n_predict"].is_null())
     {
@@ -675,8 +671,6 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";

-    std::string final_text;
-
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
@@ -693,98 +687,81 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });

-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-             {
-        if(llama.params.embedding) {
-            json data = {
-                {"status", "error"},
-                {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
+        if (llama.params.embedding) {
+            json data = {
+                {"status", "error"},
+                {"reason", "To use completion function, disable embedding mode"}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }

-        llama.rewind();
-        final_text = "";
+        llama.rewind();

-        if(parse_options_completion(json::parse(req.body), llama, res) == false){
-            return;
-        }
+        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
+            return;
+        }

-        if (!llama.loadPrompt())
-        {
-            json data = {
-                {"status", "error"},
-                {"reason", "Context too long."}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }

+        llama.beginCompletion();
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
+                llama.doCompletion();
+            }
+
+            json data = {{"content", llama.generated_text},
+                         {"stop", true},
+                         {"model", llama.params.model_alias },
+                         {"tokens_predicted", llama.num_tokens_predicted},
+                         {"generation_settings", format_generation_settings(llama)},
+                         {"prompt", llama.params.prompt},
+                         {"stopping_word", llama.stopping_word}};
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
+
-        llama.beginCompletion();
-        if(llama.streaming)
-        {
-            res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/,
-                DataSink& sink) {
-                std::string final_text = "";
-                // loop inference until finish completion
-                while (llama.has_next_token) {
-                    std::string result = llama.doCompletion();
                     json data;
-                    final_text += result;
-                    if (llama.has_next_token)
-                    {
-                        data = { {"content", result}, {"stop", false} };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = { {"content", result},
-                            {"stop", true},
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text} };
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
                     }

                     std::string str =
-                        "data: " + data.dump(4, ' ', false, json::error_handler_t::replace) +
-                        "\n\n";
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
                     sink.write(str.data(), str.size());
                 }

-                sink.done();
-                return true;
-            });
-        }
-        else
-        {
-            // loop inference until finish completion
-            while (llama.has_next_token)
-            {
-                llama.doCompletion();
-            }
-            try
-            {
-                json data = {
-                    {"model", llama.params.model_alias },
-                    {"content", llama.generated_text },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word} };
-                return res.set_content(data.dump(), "application/json");
-            }
-            catch (const json::exception &e)
-            {
-                // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                json data = {
-                    {"content", "Bad encoding token"},
-                    {"tokens_predicted", 0}};
-                return res.set_content(data.dump(), "application/json");
-            }
-        } });

+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });

     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
             {
@@ -811,7 +788,6 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });

-
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);

     if(params.embedding) {
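For completeness, here is a rough sketch (not part of this commit) of how a client could split the resulting event stream back into JSON payloads. It assumes only the "data: {...}\n\n" framing produced by the chunked content provider above and the nlohmann/json header; the class and its behavior are hypothetical.

    // Hypothetical client-side helper: feed it raw chunks from the /completion response
    // and it extracts each "data: {...}" event and parses the JSON payload.
    #include <iostream>
    #include <string>
    #include "json.hpp"  // nlohmann/json

    class SseParser {
        std::string buffer_;
    public:
        // Append a received chunk, then handle every complete "\n\n"-terminated event.
        void feed(const std::string &chunk) {
            buffer_ += chunk;
            std::string::size_type end;
            while ((end = buffer_.find("\n\n")) != std::string::npos) {
                std::string event = buffer_.substr(0, end);
                buffer_.erase(0, end + 2);
                const std::string prefix = "data: ";
                if (event.compare(0, prefix.size(), prefix) == 0) {
                    nlohmann::json data = nlohmann::json::parse(event.substr(prefix.size()));
                    std::cout << data.value("content", std::string{});
                    if (data.value("stop", false)) {
                        std::cout << "\n[done after "
                                  << data.value("tokens_predicted", 0) << " tokens]\n";
                    }
                }
            }
        }
    };

    int main() {
        // Simulated chunks; a real client would pass whatever its HTTP library delivers.
        SseParser p;
        p.feed("data: {\"content\":\"Hel\",\"stop\":false}\n\ndata: {\"content\":\"lo\",");
        p.feed("\"stop\":true,\"tokens_predicted\":2}\n\n");
        return 0;
    }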
