
Commit 291a785

llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the existing TODO comment in `include/llama.h`) is that the `logits` field actually specifies whether any output should be computed for a token, not just logits. For example, when generating embeddings, setting `logits` to true is confusing since no logits are involved; `output` describes the field's role accurately in both cases.
1 parent 9f4cc8f · commit 291a785
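
To illustrate the renamed API, here is a minimal sketch of prompt evaluation after this change, assuming a loaded context `ctx` and a tokenized prompt `tokens` (the calls mirror the updated examples in the diff below):

    // allocate a token batch; the per-token flag array is now `output`
    llama_batch batch = llama_batch_init(512, 0, 1);

    for (size_t i = 0; i < tokens.size(); i++) {
        // last argument: request output for this token (previously `logits`)
        common_batch_add(batch, tokens[i], i, { 0 }, false);
    }

    // request output (logits or embeddings, depending on how the context
    // is configured) only for the last token of the prompt
    batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
    }

    llama_batch_free(batch);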

19 files changed: +52 −53 lines


common/common.cpp

Lines changed: 3 additions & 3 deletions
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ", pos " << std::to_string(batch.pos[i])
             << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
+            << ", output " << std::to_string(batch.output[i]);
     }
 
     buf << " ]";
@@ -1617,7 +1617,7 @@ void common_batch_add(
         llama_token id,
         llama_pos pos,
         const std::vector<llama_seq_id> & seq_ids,
-        bool logits) {
+        bool output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
 
     batch.token [batch.n_tokens] = id;
@@ -1626,7 +1626,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = output;
 
     batch.n_tokens++;
 }

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };
 
         const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
                 common_batch_add(batch, 0, i, { j }, false);
             }
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
 
         const auto t_pp_start = ggml_time_us();

examples/batched.swift/Sources/main.swift

Lines changed: 3 additions & 3 deletions
@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }
 
 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1
 
 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")
@@ -171,7 +171,7 @@ while n_cur <= n_len {
         if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
             seq_id[0] = Int32(i)
         }
-        batch.logits[Int(batch.n_tokens)] = 1
+        batch.output[Int(batch.n_tokens)] = 1
 
         i_batch[i] = batch.n_tokens
 

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_add(*batch, 0, i, { 0 }, false);
     }
 
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
    llama_kv_cache_clear(context);
 
    const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
    for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
    return reinterpret_cast<jlong>(batch);
 }
@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
    }
 
    // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
    if (llama_decode(context, *batch) != 0) {
        LOGe("llama_decode() failed");

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 4 additions & 4 deletions
@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
     batch.n_tokens = 0
 }
 
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
     batch.token   [Int(batch.n_tokens)] = id
     batch.pos     [Int(batch.n_tokens)] = pos
     batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output  [Int(batch.n_tokens)] = outputs ? 1 : 0
 
     batch.n_tokens += 1
 }
@@ -139,7 +139,7 @@ actor LlamaContext {
             let i = Int(i1)
             llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed")
@@ -208,7 +208,7 @@ actor LlamaContext {
         for i in 0..<n_tokens {
             llama_batch_add(&batch, 0, Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
         llama_kv_cache_clear(context)
 

examples/llava/llava.cpp

Lines changed: 4 additions & 4 deletions
@@ -441,13 +441,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;
@@ -458,13 +458,13 @@ struct llava_embd_batch {
             /*pos      =*/ pos.data(),
             /*n_seq_id =*/ n_seq_id.data(),
             /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
+            /*output   =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
 };

examples/parallel/parallel.cpp

Lines changed: 2 additions & 2 deletions
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
 
             // extract the logits only for the last token
             if (batch.n_tokens > 0) {
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;
             }
 
             client.n_prompt = tokens_prompt.size();
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);

examples/passkey/passkey.cpp

Lines changed: 2 additions & 2 deletions
@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
        }
 
        if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
        }
 
        if (llama_decode(ctx, batch) != 0) {

examples/perplexity/perplexity.cpp

Lines changed: 7 additions & 7 deletions
@@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 batch.pos     [idx]    = j*n_batch + k;
                 batch.n_seq_id[idx]    = 1;
                 batch.seq_id  [idx][0] = seq;
-                batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+                batch.output  [idx]    = batch.pos[idx] >= first ? 1 : 0;
 
-                n_outputs += batch.logits[idx] != 0;
+                n_outputs += batch.output[idx] != 0;
             }
             batch.n_tokens += batch_size;
 
@@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };
 
         const int ret = llama_decode(ctx, batch_view);
@@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
 
         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
         }
 
         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
                common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
            }
-            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
            n_logits += 1;
 
            for (int s = 0; s < 4; ++s) {
@@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
                common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
            }
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
            n_logits += 1;
 
            for (int s = 0; s < 2; ++s) {
@@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
            //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
            common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
        }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
        n_logits += 1;
 
        for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

examples/retrieval/retrieval.cpp

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

examples/save-load-state/save-load-state.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < tokens.size(); i++) {
         common_batch_add(batch, tokens[i], i, {0}, false);
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    batch.output[batch.n_tokens - 1] = true; // generate next token
 
     // evaluate prompt
     llama_decode(ctx, batch);

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
@@ -2413,7 +2413,7 @@ struct server_context {
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+            if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
@@ -2451,7 +2451,7 @@ struct server_context {
         res->n_tokens = slot.n_prompt_tokens;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+            if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
@@ -3109,7 +3109,7 @@ struct server_context {
             }
 
             // extract the logits only for the last token
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
 
             slot.n_decoded = 0;
             slot.i_batch   = batch.n_tokens - 1;
@@ -3149,7 +3149,7 @@ struct server_context {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);

examples/tts/tts.cpp

Lines changed: 1 addition & 1 deletion
@@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx_ttc, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int8_t       *  output;
     } llama_batch;
 
     enum llama_model_kv_override_type {
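
For context, after this hunk the batch struct reads roughly as follows. This is a sketch: the fields above `pos` are not shown in the diff and are reconstructed from `include/llama.h`, so treat them as assumptions.

    typedef struct llama_batch {
        int32_t n_tokens;

        llama_token  *  token;    // token ids (used when embd is NULL)
        float        *  embd;     // token embeddings (used instead of token ids)
        llama_pos    *  pos;      // position of each token in its sequence
        int32_t      *  n_seq_id; // number of sequence ids per token
        llama_seq_id ** seq_id;   // sequence ids per token
        int8_t       *  output;   // non-zero: compute output for this token
                                  // (named `logits` before this commit)
    } llama_batch;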

src/llama-batch.cpp

Lines changed: 9 additions & 9 deletions
@@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
                 ubatch.output[ubatch.n_tokens + i] = 1;
                 out_ids.push_back(ids[seq.offset + i]);
             }
-        } else if (batch->logits) {
+        } else if (batch->output) {
             if (ubatch.equal_seqs) {
                 for (size_t i = 0; i < length; ++i) {
                     size_t id = ids[seq.offset + i];
-                    int8_t is_output = batch->logits[id];
+                    int8_t is_output = batch->output[id];
                     ubatch.output[ubatch.n_tokens + i] = is_output;
                     if (is_output) { out_ids.push_back(id); }
                 }
             } else {
                 // simple split
-                ubatch.output = batch->logits + seq.offset;
+                ubatch.output = batch->output + seq.offset;
                 for (size_t i = 0; i < length; ++i) {
                     if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                 }
@@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
-    if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+    if (!batch.output) {
+        outputs.resize(batch.n_tokens);
+        outputs[outputs.size() - 1] = true;
+        batch.output = outputs.data();
     }
 }
 
@@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     }
     batch.seq_id[n_tokens_alloc] = nullptr;
 
-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
 
     return batch;
 }
@@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
         }
         free(batch.seq_id);
     }
-    if (batch.logits) free(batch.logits);
+    if (batch.output) free(batch.output);
 }
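
The `llama_batch_allocr` hunk above preserves a useful default: if a caller leaves `batch.output` as NULL, an output array is filled in that marks only the last token, matching the old `logits` behavior. A hedged sketch of a caller relying on that default (it assumes `llama_batch_get_one` takes a token pointer and count, which is not shown in this diff):

    // llama_batch_get_one() returns a batch whose output pointer is NULL;
    // llama_batch_allocr then marks only the final token for output.
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    llama_decode(ctx, batch); // output is produced for the last token only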
