
Implementation of a sequence repetition penalty sampler #2593

Draft: KerfuffleV2 wants to merge 21 commits into master from feat-seqrep-sampler-simple
Commits (21)
11fa3df
Implementation of a sequence repetition penalty
KerfuffleV2 Aug 12, 2023
34175b0
Expand simple-inference command support
KerfuffleV2 Nov 2, 2023
e2990ff
Fix batched-bench directly depending on common.o
KerfuffleV2 Nov 2, 2023
a10f7cd
Fix logic in simple-inference chunk concat and dump
KerfuffleV2 Nov 2, 2023
a0c5587
Expand simple-inference command handling.
KerfuffleV2 Nov 3, 2023
87061ca
Remove build-info.h include
KerfuffleV2 Nov 3, 2023
63b3776
Fix invalid seqnum in commands when seqnum omitted in some cases.
KerfuffleV2 Nov 3, 2023
557d867
Minor cleanups.
KerfuffleV2 Nov 9, 2023
930e132
Let's try merging master instead of rebasing for a little change of pace
KerfuffleV2 Nov 13, 2023
3c76bd6
convert.py: also look for plain model.safetensors (#4043)
afrideva Nov 14, 2023
2751031
stablelm : StableLM support (#3586)
Galunid Nov 14, 2023
affa88b
Fix MacOS Sonoma model quantization (#4052)
TortoiseHam Nov 14, 2023
208bdcd
ggml-cuda : increase max graph size (#4084)
slaren Nov 15, 2023
4fc5f7d
llama : restore prefix space in llama tokenizer (#4081)
cebtenzzre Nov 15, 2023
b94b982
gguf : fix potential infinite loops while parsing (#4100)
texmex76 Nov 16, 2023
c301973
Respect tokenizer.ggml.add_bos_token value when tokenizing (#4040)
KerfuffleV2 Nov 17, 2023
16868e2
Merge branch 'master' into feat-seqrep-sampler-simple
KerfuffleV2 Nov 17, 2023
f109568
Merge branch 'master' into feat-seqrep-sampler-simple
KerfuffleV2 Nov 18, 2023
89262de
Merge branch 'master' into feat-seqrep-sampler-simple
KerfuffleV2 Nov 18, 2023
046a469
Fix(ish?) prompt tokenizing
KerfuffleV2 Nov 18, 2023
dc1e34a
Merge branch 'master' into feat-seqrep-sampler-simple
KerfuffleV2 Nov 24, 2023
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -96,6 +96,8 @@ option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging"
option(LLAMA_MPI "llama: use MPI" OFF)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)

option(LLAMA_SEQREP_SAMPLER "llama: build with support for seqrep sampler" ON)

option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ON)
15 changes: 13 additions & 2 deletions Makefile
@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora simple-inference tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
@@ -572,6 +572,14 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@

ifndef LLAMA_NO_SEQREP_SAMPLER
COMMON_H_DEFS += common/seqrep-sampler.h
COMMON_DEPS += seqrep-sampler.o

seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
endif

libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)

@@ -594,13 +602,16 @@ infill: examples/infill/infill.cpp ggml.o llama.o $(C
simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

simple-inference: examples/simple-inference/simple-inference.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

tokenize: examples/tokenize/tokenize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

batched: examples/batched/batched.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.o llama.o common.o $(OBJS)
batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

quantize: examples/quantize/quantize.cpp build-info.o ggml.o llama.o $(OBJS)
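With the Makefile changes above, `make simple-inference` builds the new example, `batched-bench` links the full `$(COMMON_DEPS)` set instead of depending on `common.o` directly, and defining `LLAMA_NO_SEQREP_SAMPLER` skips building the sampler objects entirely.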
2 changes: 2 additions & 0 deletions build.zig
@@ -111,6 +111,8 @@ pub fn build(b: *std.build.Builder) !void {
var make = try Maker.init(b);
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;

try make.addFlag("-DLLAMA_NO_SEQREP_SAMPLER");

const ggml = make.obj("ggml", "ggml.c");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
6 changes: 6 additions & 0 deletions common/CMakeLists.txt
@@ -54,6 +54,12 @@ add_library(${TARGET} STATIC
train.cpp
)

if (LLAMA_SEQREP_SAMPLER)
target_sources(${TARGET} PRIVATE seqrep-sampler.h seqrep-sampler.cpp)
else()
add_compile_definitions(LLAMA_NO_SEQREP_SAMPLER)
endif()

if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
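Taken together, the build changes keep the sampler optional: CMake exposes `LLAMA_SEQREP_SAMPLER` (ON by default; configuring with `-DLLAMA_SEQREP_SAMPLER=OFF` defines `LLAMA_NO_SEQREP_SAMPLER` instead), the Makefile honors a `LLAMA_NO_SEQREP_SAMPLER` variable, and build.zig currently opts out by always passing `-DLLAMA_NO_SEQREP_SAMPLER`.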
26 changes: 26 additions & 0 deletions common/common.cpp
@@ -1,6 +1,10 @@
#include "common.h"
#include "llama.h"

#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "seqrep-sampler.h"
#endif

#include <algorithm>
#include <cassert>
#include <cmath>
@@ -336,6 +340,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.penalty_present = std::stof(argv[i]);
#ifndef LLAMA_NO_SEQREP_SAMPLER
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
if (std::strcmp(argv[i], "help") == 0) {
seqrep_sampler_help();
exit(0);
}
llama_sampler_seqrep_params sr_params;
seqrep_sampler_params_init(&sr_params);
if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
seqrep_sampler_help();
exit(1);
}
sparams.seqrep_params.push_back(sr_params);
#endif
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
@@ -770,6 +792,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(" -seqrep CFG, --seqrep-penalty CFG\n");
printf(" add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
#endif
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
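As the help text above notes, each `-seqrep CFG` (or `--seqrep-penalty CFG`) flag adds one independently configured copy of the sampler, so several can be stacked in a single run; passing `-seqrep help` prints the CFG syntax and exits.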
10 changes: 9 additions & 1 deletion common/sampling.cpp
@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
const std::vector<llama_token> & all_last_tokens) {
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -155,6 +156,13 @@ llama_token llama_sampling_sample(
prev.data() + prev.size() - penalty_last_n,
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);

#ifndef LLAMA_NO_SEQREP_SAMPLER
for (auto & sr_params : params.seqrep_params) {
if ((sr_params.flags & LLAMA_SEQREP_REWIND_MODE) != 0) continue;
llama_sample_seqrep_penalty(ctx_main, &cur_p, all_last_tokens, &sr_params);
}
#endif

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
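To show how the new trailing parameter is meant to be fed, here is a hypothetical caller sketch; `sample_next` and `history` are illustrative names, not part of this PR, and only the `llama_sampling_sample` signature from this diff is assumed:

```cpp
// Hypothetical caller sketch -- not part of this PR. It relies only on the
// llama_sampling_sample signature shown above; the trailing argument defaults
// to an empty vector, so existing call sites keep compiling unchanged.
#include "sampling.h"

#include <vector>

static llama_token sample_next(llama_sampling_context * ctx_sampling,
                               llama_context * ctx_main,
                               std::vector<llama_token> & history) {
    // Pass the full token history so non-rewind seqrep samplers can penalize
    // repeated sequences beyond the usual repeat-penalty window.
    const llama_token id = llama_sampling_sample(
        ctx_sampling, ctx_main, /*ctx_cfg =*/ nullptr, /*idx =*/ 0, history);

    history.push_back(id); // a real caller would also call llama_sampling_accept
    return id;
}
```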
12 changes: 11 additions & 1 deletion common/sampling.h
@@ -4,6 +4,10 @@

#include "grammar-parser.h"

#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "seqrep-sampler.h"
#endif

#include <string>
#include <vector>
#include <unordered_map>
@@ -35,6 +39,11 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance

std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

#ifndef LLAMA_NO_SEQREP_SAMPLER
std::vector<llama_sampler_seqrep_params> seqrep_params;
#endif

} llama_sampling_params;

// general sampler context
@@ -101,7 +110,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
int idx = 0,
const std::vector<llama_token> & all_last_tokens = {});

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
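For completeness, a minimal sketch of filling `seqrep_params` programmatically, assuming only the calls visible in this diff; `add_seqrep` is a hypothetical helper, and the CFG string format (documented by `-seqrep help`) is not reproduced here:

```cpp
// Minimal sketch, not part of this PR: mirrors the init/parse/push pattern
// from common.cpp above. add_seqrep() is a hypothetical helper; cfg uses the
// same CFG syntax as the -seqrep command-line flag.
#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "seqrep-sampler.h"
#endif
#include "sampling.h"

static bool add_seqrep(llama_sampling_params & sparams, char * cfg) {
#ifndef LLAMA_NO_SEQREP_SAMPLER
    llama_sampler_seqrep_params sr_params;
    seqrep_sampler_params_init(&sr_params);              // start from defaults
    if (!seqrep_sampler_params_parse(cfg, &sr_params)) {
        return false;                                    // rejected CFG string
    }
    sparams.seqrep_params.push_back(sr_params);          // one entry per sampler copy
    return true;
#else
    (void) sparams; (void) cfg;                          // sampler compiled out
    return false;
#endif
}
```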