
Commit ab045a1
add a model to support ollama
zrquan committed Mar 3, 2025
1 parent 70ea45c commit ab045a1
Showing 3 changed files with 73 additions and 0 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -11,3 +11,4 @@ dependencies:
- openai
- accelerate
- transformers
- ollama
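
The new `ollama` entry is the official Python client for a local Ollama server; if it isn't available on the project's conda channels, it would go under a `pip:` subsection instead. A quick sanity check that the client can reach a running daemon (a minimal sketch, assuming the server listens on the default port):

import ollama

# Connects to http://localhost:11434 by default; the listing shows which
# models have already been pulled into the local Ollama instance.
client = ollama.Client()
print(client.list())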
6 changes: 6 additions & 0 deletions src/models/llm.py
@@ -32,6 +32,9 @@ def __init__(self, model_name, logger: MyLogger, model_name_map, **kwargs):
        # nothing else needed if calling gpt
        if model_name.lower().startswith("gpt"):
            return
        # nothing else needed if calling ollama api
        elif model_name.lower().startswith("ollama"):
            return
        # nothing else needed if calling together AI
        elif "-tai" in model_name.lower():
            return
@@ -231,6 +234,9 @@ def get_llm(model_name, kwargs, logger):
    elif model_name.lower().startswith("gpt"):
        from models.gpt import GPTModel
        model = GPTModel(model_name=model_name, logger=logger, **kwargs)
    elif model_name.lower().startswith("ollama"):
        from models.ollama import OllamaModel
        model = OllamaModel(model_name=model_name, logger=logger, **kwargs)
    elif model_name.lower().startswith("gemma"):
        from models.google import GoogleModel
        model = GoogleModel(model_name=model_name, logger=logger, **kwargs)
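With the new branch in place, any model name prefixed with `ollama` is dispatched to the new backend. A minimal sketch of the routing (assumes `logger` is an existing `MyLogger` instance; the kwargs shown are illustrative):

from models.llm import get_llm

# "ollama-qwen-coder" starts with "ollama", so get_llm imports OllamaModel,
# which resolves the alias to the "qwen2.5-coder:latest" tag via its
# _model_name_map.
model = get_llm("ollama-qwen-coder", {"temperature": 0}, logger)
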
66 changes: 66 additions & 0 deletions src/models/ollama.py
@@ -0,0 +1,66 @@
import os
import ollama
from tqdm.contrib.concurrent import thread_map

from src.models.llm import LLM
from src.utils.mylogger import MyLogger

_model_name_map = {
    "ollama-qwen-coder": "qwen2.5-coder:latest",
    "ollama-qwen": "qwen2.5:32b",
    "ollama-llama3": "llama3.2:latest",
    "ollama-deepseek-32b": "deepseek-r1:32b",
    "ollama-deepseek-7b": "deepseek-r1:latest",
}
_OLLAMA_DEFAULT_OPTIONS = {
    "temperature": 0,
    "num_predict": 4096,
    "stop": None,
    "seed": 345,
}
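# Note (editor's sketch): temperature=0 plus a fixed seed aims to make
# generations reproducible across runs, and num_predict caps the number of
# generated tokens (Ollama's analogue of max_tokens).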


class OllamaModel(LLM):
    def __init__(self, model_name, logger: MyLogger, **kwargs):
        super().__init__(model_name, logger, _model_name_map, **kwargs)
        if host := os.environ.get("OLLAMA_HOST"):
            self.client = ollama.Client(host=host)
        else:
            self.log.error("OLLAMA_HOST is not set; falling back to the default local server")
            self.client = ollama.Client()
        # TODO: https://github.com/ollama/ollama/issues/2415
        # self.logprobs = None
        # Copy the defaults per instance so user overrides don't mutate
        # the shared module-level dict.
        self.options = dict(_OLLAMA_DEFAULT_OPTIONS)
        for k in _OLLAMA_DEFAULT_OPTIONS:
            if k in kwargs:
                self.options[k] = kwargs[k]

    def predict(self, prompt, batch_size=0, no_progress_bar=False):
        if batch_size == 0:
            return self._predict(prompt)
        args = range(len(prompt))
        responses = thread_map(
            lambda x: self._predict(prompt[x]),
            args,
            max_workers=batch_size,
            disable=no_progress_bar,
        )
        return responses

    def _predict(self, main_prompt):
        # assuming 0 is system and 1 is user
        system_prompt = main_prompt[0]["content"]
        user_prompt = main_prompt[1]["content"]
        prompt = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        try:
            response = self.client.chat(
                model=self.model_id,
                messages=prompt,
                options=self.options,
            )
        except ollama.ResponseError as e:
            self.log.error(f"Ollama response error: {e.error}")
            return None

        return response.message.content
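
A usage sketch for the new class (the prompts and `MyLogger` constructor arguments are hypothetical, and the mapped model tags must already be pulled into the local Ollama instance):

import os
from src.models.ollama import OllamaModel
from src.utils.mylogger import MyLogger

os.environ.setdefault("OLLAMA_HOST", "http://localhost:11434")
logger = MyLogger("ollama-demo")  # hypothetical constructor arguments

chat = [
    {"role": "system", "content": "You are a security code reviewer."},
    {"role": "user", "content": "Explain what a use-after-free bug is."},
]

# "ollama-qwen" maps to the "qwen2.5:32b" tag via _model_name_map.
model = OllamaModel("ollama-qwen", logger, temperature=0)

print(model.predict(chat))                        # single prompt
print(model.predict([chat, chat], batch_size=2))  # batched via thread_map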
