support ollama #24

Merged 1 commit on Mar 6, 2025
README.md: 10 additions, 0 deletions

@@ -375,6 +375,16 @@ We support the following models with our models API wrapper (found in `src/model
- `wizardlm-13b`
- `wizardlm-30b`

### Ollama

You need to install the `ollama` Python package manually, and the wrapper reads the Ollama server address from the `OLLAMA_HOST` environment variable; a minimal setup sketch follows the model list below.

- `qwen2.5-coder:latest`
- `qwen2.5:32b`
- `llama3.2:latest`
- `deepseek-r1:32b`
- `deepseek-r1:latest`
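
A minimal setup sketch to confirm the package and server are available before running the wrapper (the host value is an assumption; 11434 is Ollama's default port):

```python
# Sketch: check that the ollama package is installed and the server is reachable.
# Point OLLAMA_HOST at your own server; http://localhost:11434 is Ollama's default.
import os
import ollama

os.environ.setdefault("OLLAMA_HOST", "http://localhost:11434")
client = ollama.Client(host=os.environ["OLLAMA_HOST"])
print(client.list())  # should list locally pulled tags such as "qwen2.5:32b"
```

The wrapper reads `OLLAMA_HOST` when `OllamaModel` is constructed, so the variable must be set before running.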

</details>

## Adding a CWE

src/models/llm.py: 6 additions, 0 deletions

@@ -32,6 +32,9 @@ def __init__(self, model_name, logger: MyLogger, model_name_map, **kwargs):
        # nothing else needed if calling gpt
        if model_name.lower().startswith("gpt"):
            return
        # nothing else needed if calling ollama api
        elif model_name.lower().startswith("ollama"):
            return
        # nothing else needed if calling together AI
        elif "-tai" in model_name.lower():
            return

@@ -231,6 +234,9 @@ def get_llm(model_name, kwargs, logger):
    elif model_name.lower().startswith("gpt"):
        from models.gpt import GPTModel
        model = GPTModel(model_name=model_name, logger=logger, **kwargs)
    elif model_name.lower().startswith("ollama"):
        from models.ollama import OllamaModel
        model = OllamaModel(model_name=model_name, logger=logger, **kwargs)
    elif model_name.lower().startswith("gemma"):
        from models.google import GoogleModel
        model = GoogleModel(model_name=model_name, logger=logger, **kwargs)
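
For reference, the new backend is selected through the same `get_llm` entry point. A minimal sketch (import path, model key, and option values are illustrative; `logger` is assumed to be the project's `MyLogger` instance):

```python
from models.llm import get_llm

def load_ollama_model(logger):
    # "ollama-qwen" is a key in ollama.py's _model_name_map (maps to "qwen2.5:32b");
    # matching keys such as "temperature" override _OLLAMA_DEFAULT_OPTIONS.
    kwargs = {"temperature": 0.2, "seed": 0}
    return get_llm("ollama-qwen", kwargs, logger)
```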

src/models/ollama.py: 69 additions, 0 deletions

@@ -0,0 +1,69 @@
import os
import ollama
from tqdm.contrib.concurrent import thread_map

from src.models.llm import LLM
from src.utils.mylogger import MyLogger

_model_name_map = {
    "ollama-qwen-coder": "qwen2.5-coder:latest",
    "ollama-qwen": "qwen2.5:32b",
    "ollama-llama3": "llama3.2:latest",
    "ollama-deepseek-32b": "deepseek-r1:32b",
    "ollama-deepseek-7b": "deepseek-r1:latest",
}

# default model parameters, add or modify according to your needs
# see https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
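# NOTE: keys passed as keyword arguments to OllamaModel(...) overwrite these
# module-level defaults (see the loop at the end of __init__ below).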
_OLLAMA_DEFAULT_OPTIONS = {
    "temperature": 0.8,
    "num_predict": -1,
    "stop": None,
    "seed": 0,
}


class OllamaModel(LLM):
    def __init__(self, model_name, logger: MyLogger, **kwargs):
        super().__init__(model_name, logger, _model_name_map, **kwargs)
        if host := os.environ.get("OLLAMA_HOST"):
            self.client = ollama.Client(host=host)
        else:
            self.log.error("Please set OLLAMA_HOST environment variable")
        # TODO: https://github.com/ollama/ollama/issues/2415
        # self.logprobs = None
        for k in _OLLAMA_DEFAULT_OPTIONS:
            if k in kwargs:
                _OLLAMA_DEFAULT_OPTIONS[k] = kwargs[k]

    def predict(self, prompt, batch_size=0, no_progress_bar=False):
        if batch_size == 0:
            return self._predict(prompt)
        args = range(0, len(prompt))
        responses = thread_map(
            lambda x: self._predict(prompt[x]),
            args,
            max_workers=batch_size,
            disable=no_progress_bar,
        )
        return responses

    def _predict(self, main_prompt):
        # assuming 0 is system and 1 is user
        system_prompt = main_prompt[0]["content"]
        user_prompt = main_prompt[1]["content"]
        prompt = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        try:
            response = self.client.chat(
                model=self.model_id,
                messages=prompt,
                options=_OLLAMA_DEFAULT_OPTIONS,
            )
        except ollama.ResponseError as e:
            print("Ollama Response Error:", e.error)
            return None

        return response.message.content
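
`_predict` unpacks a two-element message list: the system prompt at index 0 and the user prompt at index 1. A usage sketch (the prompt text is illustrative, and the `OllamaModel` instance is assumed to have been built elsewhere, e.g. via `get_llm`):

```python
# Sketch of the prompt shape OllamaModel.predict/_predict expects.
single_prompt = [
    {"role": "system", "content": "You are a careful security reviewer."},
    {"role": "user", "content": "Does this function bound the copy by the destination size?"},
]

# answer = model.predict(single_prompt)                         # single request
# answers = model.predict([single_prompt] * 4, batch_size=2)    # threaded batch via thread_map
```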