From d75004730af3cd992f7438ff39eca665e164560a Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 13 Mar 2024 17:10:27 -0500 Subject: [PATCH] Add Customer SamplerStage --- crates/llama_cpp/src/standard_sampler.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index e280ee7..fc79a20 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -111,6 +111,9 @@ pub enum SamplerStage { /// /// See [`GrammarStage`] and [`LlamaGrammar`] for more information. Grammar(GrammarStage), + + /// A custom, stateless [`SamplerStage`] defined using a function pointer. + Custom(fn(*mut llama_context, &[Token], llama_token_data_array, usize) -> llama_token_data_array) } impl SamplerStage { @@ -195,6 +198,9 @@ impl SamplerStage { SamplerStage::Grammar(stage) => { candidates_p = stage.apply(context, tokens, candidates_p, min_keep) } + SamplerStage::Custom(func) => { + candidates_p = func(context, tokens, candidates_p, min_keep) + }, } }