Skip to content

Commit 436e561

Browse files
swggerganov
andauthored
all : be more strict about converting float to double (#458)
* Be more strict about converting float to double * Test equivalence of round, SILU implementations Test module is commented out in CMakeLists.txt because the tests may take a long time, depending on how much the compiler optimizes. * Fix softmax in perplexity.cpp * all : prefer float over double where appropriate * perplexity : add <cmath> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 20e1e84 commit 436e561

File tree

11 files changed

+185
-117
lines changed

11 files changed

+185
-117
lines changed

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,18 @@ if (LLAMA_ALL_WARNINGS)
124124
-Wall
125125
-Wextra
126126
-Wpedantic
127-
-Wshadow
128127
-Wcast-qual
128+
-Wdouble-promotion
129+
-Wshadow
129130
-Wstrict-prototypes
130131
-Wpointer-arith
131-
-Wno-unused-function
132132
)
133133
set(cxx_flags
134134
-Wall
135135
-Wextra
136136
-Wpedantic
137137
-Wcast-qual
138+
-Wdouble-promotion
138139
)
139140
else()
140141
# todo : msvc

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
3535
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
3636
LDFLAGS =
3737

38+
# warnings
39+
CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
40+
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
41+
3842
# OS specific
3943
# TODO: support Windows
4044
ifeq ($(UNAME_S),Linux)

examples/common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
215215
fprintf(stderr, " prompt file to start generation.\n");
216216
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
217217
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
218-
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
218+
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
219219
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
220-
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
220+
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
221221
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
222222
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
223223
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
224-
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
224+
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
225225
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
226226
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
227227
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");

examples/main/main.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ int main(int argc, char ** argv) {
209209
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
210210
}
211211
}
212-
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
212+
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
213+
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
213214
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
214215
fprintf(stderr, "\n\n");
215216

@@ -274,10 +275,10 @@ int main(int argc, char ** argv) {
274275

275276
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
276277
// out of user input, sample next token
277-
const float top_k = params.top_k;
278-
const float top_p = params.top_p;
279-
const float temp = params.temp;
280-
const float repeat_penalty = params.repeat_penalty;
278+
const int32_t top_k = params.top_k;
279+
const float top_p = params.top_p;
280+
const float temp = params.temp;
281+
const float repeat_penalty = params.repeat_penalty;
281282

282283
llama_token id = 0;
283284

examples/perplexity/perplexity.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "common.h"
22
#include "llama.h"
33

4-
std::vector<double> softmax(const std::vector<float>& logits) {
5-
std::vector<double> probs(logits.size());
4+
#include <cmath>
5+
6+
std::vector<float> softmax(const std::vector<float>& logits) {
7+
std::vector<float> probs(logits.size());
68
float max_logit = logits[0];
79
for (float v : logits) max_logit = std::max(max_logit, v);
810
double sum_exp = 0.0;
911
for (size_t i = 0; i < logits.size(); i++) {
1012
// Subtract the maximum logit value from the current logit value for numerical stability
11-
float logit = logits[i] - max_logit;
12-
double exp_logit = std::exp(logit);
13+
const float logit = logits[i] - max_logit;
14+
const float exp_logit = expf(logit);
1315
sum_exp += exp_logit;
1416
probs[i] = exp_logit;
1517
}
@@ -24,14 +26,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
2426
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
2527

2628
int count = 0;
27-
double nll = 0.0;
2829
int seq_count = tokens.size() / params.n_ctx;
2930

31+
double nll = 0.0;
32+
3033
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
3134

3235
for (int i = 0; i < seq_count; ++i) {
3336
int start = i * params.n_ctx;
34-
int end = start + params.n_ctx - 1;
37+
int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
38+
// it is better to always be power of 2 for better performance
3539
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
3640
auto start_t = std::chrono::high_resolution_clock::now();
3741
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
@@ -40,7 +44,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
4044
}
4145
auto end_t = std::chrono::high_resolution_clock::now();
4246
if (i == 0) {
43-
double seconds = std::chrono::duration<double>(end_t - start_t).count();
47+
const float seconds = std::chrono::duration<float>(end_t - start_t).count();
4448
printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
4549
}
4650
// We get the logits for all the tokens in the context window (params.n_ctx)
@@ -63,7 +67,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
6367
std::vector<float> tok_logits(
6468
logits + j * n_vocab,
6569
logits + (j + 1) * n_vocab);
66-
double prob = softmax(tok_logits)[tokens[start + j + 1]];
70+
const float prob = softmax(tok_logits)[tokens[start + j + 1]];
6771
nll += -std::log(prob);
6872
++count;
6973
}

examples/quantize/quantize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
5050
const int64_t t_main_end_us = ggml_time_us();
5151

5252
printf("\n");
53-
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
54-
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
53+
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
54+
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
5555
}
5656

5757
return 0;

0 commit comments

Comments
 (0)