Skip to content

Commit 440fd20

Browse files
authored
Remove unprintable characters from vocab
Fixes ggml-org#11 This fixes a Japanese prompt I was attempting to run EG: `./main -m ./models/13B/ggml-model-q4_0.bin -t 8 -n 128 -n 512 -p $'人生の意味は'` Output before change: `人生の意���、フロントカードに���いてる。 2019年3月 © All Rights Reserved. [end of text]` So it is outputting some characters but some � Output after change: `人生の意は、一人が一人ということであります。は安部が立していたので、去からは一人の人にれるのはにとどまったのですが、そう`
1 parent 4235e3d commit 440fd20

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

main.cpp

+21-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <map>
1111
#include <string>
1212
#include <vector>
13+
#include <unordered_set>
1314

1415
// determine number of model parts based on the dimension
1516
static const std::map<int, int> LLAMA_N_PARTS = {
@@ -123,6 +124,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
123124
}
124125

125126
// load vocab
127+
128+
std::unordered_set<std::string> unprintable_characters = {"", "", "��"};
129+
126130
{
127131
const int32_t n_vocab = model.hparams.n_vocab;
128132

@@ -131,7 +135,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
131135
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
132136
return false;
133137
}
134-
138+
135139
std::string word;
136140
for (int i = 0; i < n_vocab; i++) {
137141
uint32_t len;
@@ -140,6 +144,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
140144
word.resize(len);
141145
fin.read((char *) word.data(), len);
142146

147+
if(unprintable_characters.find(word) != unprintable_characters.end()) {
148+
continue;
149+
}
150+
143151
vocab.token_to_id[word] = i;
144152
vocab.id_to_token[i] = word;
145153

@@ -792,7 +800,7 @@ int main(int argc, char ** argv) {
792800
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
793801
}
794802
printf("\n");
795-
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
803+
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
796804
printf("\n\n");
797805

798806
std::vector<gpt_vocab::id> embd;
@@ -801,6 +809,10 @@ int main(int argc, char ** argv) {
801809
size_t mem_per_token = 0;
802810
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
803811

812+
int last_n_size = params.repeat_last_n;
813+
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
814+
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
815+
804816
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
805817
// predict
806818
if (embd.size() > 0) {
@@ -821,6 +833,7 @@ int main(int argc, char ** argv) {
821833
// sample next token
822834
const float top_p = params.top_p;
823835
const float temp = params.temp;
836+
const float repeat_penalty = params.repeat_penalty;
824837

825838
const int n_vocab = model.hparams.n_vocab;
826839

@@ -829,7 +842,10 @@ int main(int argc, char ** argv) {
829842
{
830843
const int64_t t_start_sample_us = ggml_time_us();
831844

832-
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
845+
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
846+
847+
last_n_tokens.erase(last_n_tokens.begin());
848+
last_n_tokens.push_back(id);
833849

834850
t_sample_us += ggml_time_us() - t_start_sample_us;
835851
}
@@ -840,6 +856,8 @@ int main(int argc, char ** argv) {
840856
// if here, it means we are still processing the input prompt
841857
for (int k = i; k < embd_inp.size(); k++) {
842858
embd.push_back(embd_inp[k]);
859+
last_n_tokens.erase(last_n_tokens.begin());
860+
last_n_tokens.push_back(embd_inp[k]);
843861
if (embd.size() > params.n_batch) {
844862
break;
845863
}

0 commit comments

Comments
 (0)