diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index e280ee7..fc79a20 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -111,6 +111,9 @@ pub enum SamplerStage { /// /// See [`GrammarStage`] and [`LlamaGrammar`] for more information. Grammar(GrammarStage), + + /// A custom, stateless [`SamplerStage`] defined using a function pointer. + Custom(fn(*mut llama_context, &[Token], llama_token_data_array, usize) -> llama_token_data_array) } impl SamplerStage { @@ -195,6 +198,9 @@ impl SamplerStage { SamplerStage::Grammar(stage) => { candidates_p = stage.apply(context, tokens, candidates_p, min_keep) } + SamplerStage::Custom(func) => { + candidates_p = func(context, tokens, candidates_p, min_keep) + }, } }