From 7bf1e8f35f6c035db29dac2551afd0c680fa609c Mon Sep 17 00:00:00 2001
From: 4shen0ne <4shen.01@gmail.com>
Date: Fri, 28 Feb 2025 18:02:21 +0800
Subject: [PATCH] add a model to support ollama (closes #19)

---
 README.md            | 10 +++++++
 src/models/llm.py    |  6 ++++
 src/models/ollama.py | 69 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 85 insertions(+)
 create mode 100644 src/models/ollama.py

diff --git a/README.md b/README.md
index 93e2f50..346e686 100644
--- a/README.md
+++ b/README.md
@@ -375,6 +375,16 @@ We support the following models with our models API wrapper (found in `src/model
 - `wizardlm-13b`
 - `wizardlm-30b`
 
+### Ollama
+
+You need to install the `ollama` Python package manually (`pip install ollama`) and set the `OLLAMA_HOST` environment variable to point at your Ollama server.
+
+- `ollama-qwen-coder` (`qwen2.5-coder:latest`)
+- `ollama-qwen` (`qwen2.5:32b`)
+- `ollama-llama3` (`llama3.2:latest`)
+- `ollama-deepseek-32b` (`deepseek-r1:32b`)
+- `ollama-deepseek-7b` (`deepseek-r1:latest`)
+
 
 ## Adding a CWE
 
diff --git a/src/models/llm.py b/src/models/llm.py
index 25c9585..37d5ad8 100644
--- a/src/models/llm.py
+++ b/src/models/llm.py
@@ -32,6 +32,9 @@ def __init__(self, model_name, logger: MyLogger, model_name_map, **kwargs):
         # nothing else needed if calling gpt
         if model_name.lower().startswith("gpt"):
             return
+        # nothing else needed if calling the Ollama API
+        elif model_name.lower().startswith("ollama"):
+            return
         # nothing else needed if calling together AI
         elif "-tai" in model_name.lower():
             return
@@ -231,6 +234,9 @@ def get_llm(model_name, kwargs, logger):
     elif model_name.lower().startswith("gpt"):
         from models.gpt import GPTModel
         model=GPTModel(model_name=model_name, logger=logger, **kwargs)
+    elif model_name.lower().startswith("ollama"):
+        from models.ollama import OllamaModel
+        model=OllamaModel(model_name=model_name, logger=logger, **kwargs)
     elif model_name.lower().startswith("gemma"):
         from models.google import GoogleModel
         model=GoogleModel(model_name=model_name, logger=logger, **kwargs)
diff --git a/src/models/ollama.py b/src/models/ollama.py
new file mode 100644
index 0000000..32200b5
--- /dev/null
+++ b/src/models/ollama.py
@@ -0,0 +1,69 @@
+import os
+import ollama
+from tqdm.contrib.concurrent import thread_map
+
+from models.llm import LLM
+from utils.mylogger import MyLogger
+
+_model_name_map = {
+    "ollama-qwen-coder": "qwen2.5-coder:latest",
+    "ollama-qwen": "qwen2.5:32b",
+    "ollama-llama3": "llama3.2:latest",
+    "ollama-deepseek-32b": "deepseek-r1:32b",
+    "ollama-deepseek-7b": "deepseek-r1:latest",
+}
+
+# default model parameters, add or modify according to your needs
+# see https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+_OLLAMA_DEFAULT_OPTIONS = {
+    "temperature": 0.8,
+    "num_predict": -1,
+    "stop": None,
+    "seed": 0,
+}
+
+
+class OllamaModel(LLM):
+    def __init__(self, model_name, logger: MyLogger, **kwargs):
+        super().__init__(model_name, logger, _model_name_map, **kwargs)
+        host = os.environ.get("OLLAMA_HOST")
+        if not host:
+            self.log.error("OLLAMA_HOST is not set, falling back to the ollama client's default host")
+        self.client = ollama.Client(host=host)
+        # TODO: https://github.com/ollama/ollama/issues/2415
+        # self.logprobs = None
+        # copy the defaults so per-instance overrides do not mutate the module-level dict
+        self.options = {**_OLLAMA_DEFAULT_OPTIONS,
+                        **{k: v for k, v in kwargs.items() if k in _OLLAMA_DEFAULT_OPTIONS}}
+
+    def predict(self, prompt, batch_size=0, no_progress_bar=False):
+        if batch_size == 0:  # prompt is a single conversation
+            return self._predict(prompt)
+        args = range(0, len(prompt))  # otherwise, prompt is a list of conversations
+        responses = thread_map(
+            lambda x: self._predict(prompt[x]),
+            args,
+            max_workers=batch_size,
+            disable=no_progress_bar,
+        )
+        return responses
+
+    def _predict(self, main_prompt):
+        # assuming 0 is system and 1 is user
+        system_prompt = main_prompt[0]["content"]
+        user_prompt = main_prompt[1]["content"]
+        prompt = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        try:
+            response = self.client.chat(
+                model=self.model_id,
+                messages=prompt,
+                options=self.options,
+            )
+        except ollama.ResponseError as e:
+            self.log.error(f"Ollama response error: {e.error}")
+            return None
+
+        return response.message.content
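
For reference (not part of the patch): below is a minimal, standalone sketch of the ollama client call that OllamaModel._predict wraps. It assumes the `ollama` Python package is installed and an Ollama server is reachable; the host URL, model tag, and option values are illustrative, not taken from the repository.

    # standalone sketch of the chat call made inside OllamaModel._predict
    import ollama

    client = ollama.Client(host="http://localhost:11434")  # same role as OLLAMA_HOST in the patch
    response = client.chat(
        model="qwen2.5-coder:latest",  # any tag already pulled into the local Ollama instance
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        options={"temperature": 0.8, "seed": 0},  # subset of _OLLAMA_DEFAULT_OPTIONS
    )
    print(response.message.content)

Within the repository, the same call is reached by passing one of the `ollama-*` wrapper names (e.g. `ollama-qwen-coder`) to get_llm, which instantiates OllamaModel as shown in the llm.py hunk above.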