diff --git a/CHANGELOG.md b/CHANGELOG.md
index bcc5054..a1326e7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ### Added
 
 - Support for [llama.cpp](https://github.com/ggerganov/llama.cpp)
+- Support for [ollama](https://github.com/ollama/ollama)
 
 ## [0.10] 27/01/2024
 
diff --git a/README.md b/README.md
index 1916a21..6df05eb 100644
--- a/README.md
+++ b/README.md
@@ -17,14 +17,15 @@
 - Save chats to files
 - Vim keybinding (most common ops)
 - Copy text from/to clipboard (works only on the prompt)
+- Multiple backends
 
 ## 💎 Supported LLMs
 
 - [x] ChatGPT
-- [ ] ollama (todo)
 - [x] llama.cpp (in the `master` branch)
+- [x] ollama (in the `master` branch)
@@ -78,7 +79,10 @@ Tenere can be configured using a TOML configuration file. The file should be loc
 Here are the available general settings:
 
 - `archive_file_name`: the file name where the chat will be saved. By default it is set to `tenere.archive`
-- `model`: the llm model name. Possible values are: `chatgpt` and `llamacpp`.
+- `model`: the LLM model name. Possible values are:
+  - `chatgpt`
+  - `llamacpp`
+  - `ollama`
 
 ```toml
 archive_file_name = "tenere.archive"
@@ -152,6 +156,20 @@ url = "http://localhost:8080/v1/chat/completions"
 api_key = "Your API Key here"
 ```
 
+More info about the llama.cpp API [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md)
+
+## Ollama
+
+To use `ollama` as the backend, you'll need to provide the URL that points to the server, along with the model name:
+
+```toml
+[ollama]
+url = "http://localhost:11434/api/chat"
+model = "Your model name here"
+```
+
+More info about the ollama API [here](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion)
+
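The `[ollama]` table above feeds the new `OllamaConfig` struct introduced in `src/config.rs` further down in this patch. Below is a minimal, self-contained sketch of that mapping; it is not part of the patch, it assumes the `serde` (with derive) and `toml` crates, and `llama3` is only a placeholder model name:

```rust
use serde::Deserialize;

// Mirrors the `OllamaConfig` struct added to `src/config.rs` in this patch.
#[derive(Deserialize, Debug)]
struct OllamaConfig {
    url: String,
    model: String,
}

// Minimal stand-in for the part of tenere's `Config` that gains the `ollama` field.
#[derive(Deserialize, Debug)]
struct Config {
    ollama: Option<OllamaConfig>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same shape as the `[ollama]` example in the README section above;
    // the model name is illustrative only.
    let raw = r#"
        [ollama]
        url = "http://localhost:11434/api/chat"
        model = "llama3"
    "#;

    let config: Config = toml::from_str(raw)?;
    println!("{:?}", config.ollama);
    Ok(())
}
```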
 
 ## ⌨️ Key bindings
@@ -277,11 +295,3 @@ There are 3 modes like vim: `Normal`, `Visual` and `Insert`.
 ## ⚖️ License
 
 AGPLv3
-
-```
-
-```
-
-```
-
-```
diff --git a/src/config.rs b/src/config.rs
index 4eb1ef6..acda99c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -19,6 +19,8 @@ pub struct Config {
     pub chatgpt: ChatGPTConfig,
 
     pub llamacpp: Option<LLamacppConfig>,
+
+    pub ollama: Option<OllamaConfig>,
 }
 
 pub fn default_archive_file_name() -> String {
@@ -69,6 +71,14 @@ pub struct LLamacppConfig {
     pub api_key: Option<String>,
 }
 
+// Ollama
+
+#[derive(Deserialize, Debug, Clone)]
+pub struct OllamaConfig {
+    pub url: String,
+    pub model: String,
+}
+
 #[derive(Deserialize, Debug)]
 pub struct KeyBindings {
     #[serde(default = "KeyBindings::default_show_help")]
@@ -136,6 +146,11 @@ impl Config {
             std::process::exit(1)
         }
 
+        if app_config.llm == LLMBackend::Ollama && app_config.ollama.is_none() {
+            eprintln!("Config for Ollama is not provided");
+            std::process::exit(1)
+        }
+
         app_config
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index f7710f7..596b9db 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -31,3 +31,5 @@ pub mod history;
 pub mod chat;
 
 pub mod llamacpp;
+
+pub mod ollama;
diff --git a/src/llm.rs b/src/llm.rs
index c53b3be..5da3750 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -2,6 +2,7 @@ use crate::chatgpt::ChatGPT;
 use crate::config::Config;
 use crate::event::Event;
 use crate::llamacpp::LLamacpp;
+use crate::ollama::Ollama;
 use async_trait::async_trait;
 use serde::Deserialize;
 use std::sync::atomic::AtomicBool;
@@ -43,6 +44,7 @@ pub enum LLMRole {
 pub enum LLMBackend {
     ChatGPT,
     LLamacpp,
+    Ollama,
 }
 
 pub struct LLMModel;
@@ -52,6 +54,7 @@ impl LLMModel {
         match model {
             LLMBackend::ChatGPT => Box::new(ChatGPT::new(config.chatgpt.clone())),
             LLMBackend::LLamacpp => Box::new(LLamacpp::new(config.llamacpp.clone().unwrap())),
+            LLMBackend::Ollama => Box::new(Ollama::new(config.ollama.clone().unwrap())),
         }
     }
 }
diff --git a/src/ollama.rs b/src/ollama.rs
new file mode 100644
index 0000000..a7a7aa1
--- /dev/null
+++ b/src/ollama.rs
@@ -0,0 +1,110 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use std::sync::Arc;
+
+use crate::config::OllamaConfig;
+use crate::event::Event;
+use async_trait::async_trait;
+use tokio::sync::mpsc::UnboundedSender;
+
+use crate::llm::{LLMAnswer, LLMRole, LLM};
+use reqwest::header::HeaderMap;
+use serde_json::{json, Value};
+use std;
+use std::collections::HashMap;
+
+#[derive(Clone, Debug)]
+pub struct Ollama {
+    client: reqwest::Client,
+    url: String,
+    model: String,
+    messages: Vec<HashMap<String, String>>,
+}
+
+impl Ollama {
+    pub fn new(config: OllamaConfig) -> Self {
+        Self {
+            client: reqwest::Client::new(),
+            url: config.url,
+            model: config.model,
+            messages: Vec::new(),
+        }
+    }
+}
+
+#[async_trait]
+impl LLM for Ollama {
+    fn clear(&mut self) {
+        self.messages = Vec::new();
+    }
+
+    fn append_chat_msg(&mut self, msg: String, role: LLMRole) {
+        let mut conv: HashMap<String, String> = HashMap::new();
+        conv.insert("role".to_string(), role.to_string());
+        conv.insert("content".to_string(), msg);
+        self.messages.push(conv);
+    }
+
+    async fn ask(
+        &self,
+        sender: UnboundedSender<Event>,
+        terminate_response_signal: Arc<AtomicBool>,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let mut headers = HeaderMap::new();
+        headers.insert("Content-Type", "application/json".parse()?);
+
+        let mut messages: Vec<HashMap<String, String>> = vec![
+            (HashMap::from([
+                ("role".to_string(), "system".to_string()),
+                (
+                    "content".to_string(),
+                    "You are a helpful assistant.".to_string(),
+                ),
+            ])),
+        ];
+
+        messages.extend(self.messages.clone());
+
+        let body: Value = json!({
+            "messages": messages,
+            "model": self.model,
"stream": true, + }); + + let response = self + .client + .post(&self.url) + .headers(headers) + .json(&body) + .send() + .await?; + + match response.error_for_status() { + Ok(mut res) => { + sender.send(Event::LLMEvent(LLMAnswer::StartAnswer))?; + while let Some(chunk) = res.chunk().await? { + if terminate_response_signal.load(Ordering::Relaxed) { + sender.send(Event::LLMEvent(LLMAnswer::EndAnswer))?; + return Ok(()); + } + + let answer: Value = serde_json::from_slice(chunk.as_ref())?; + + if answer["done"].as_bool().unwrap() { + sender.send(Event::LLMEvent(LLMAnswer::EndAnswer))?; + return Ok(()); + } + + let msg = answer["message"]["content"].as_str().unwrap_or("\n"); + + sender.send(Event::LLMEvent(LLMAnswer::Answer(msg.to_string())))?; + } + } + Err(e) => return Err(Box::new(e)), + } + + sender.send(Event::LLMEvent(LLMAnswer::EndAnswer))?; + + Ok(()) + } +}