diff --git a/environment.yml b/environment.yml
index e924b55..2f1b7a1 100644
--- a/environment.yml
+++ b/environment.yml
@@ -11,3 +11,4 @@ dependencies:
   - openai
   - accelerate
   - transformers
+  - ollama
diff --git a/src/models/llm.py b/src/models/llm.py
index 25c9585..37d5ad8 100644
--- a/src/models/llm.py
+++ b/src/models/llm.py
@@ -32,6 +32,9 @@ def __init__(self, model_name, logger: MyLogger, model_name_map, **kwargs):
         # nothing else needed if calling gpt
         if model_name.lower().startswith("gpt"):
             return
+        # nothing else needed if calling ollama api
+        elif model_name.lower().startswith("ollama"):
+            return
         # nothing else needed if calling together AI
         elif "-tai" in model_name.lower():
             return
@@ -231,6 +234,9 @@ def get_llm(model_name, kwargs, logger):
     elif model_name.lower().startswith("gpt"):
         from models.gpt import GPTModel
         model=GPTModel(model_name=model_name, logger=logger, **kwargs)
+    elif model_name.lower().startswith("ollama"):
+        from models.ollama import OllamaModel
+        model = OllamaModel(model_name=model_name, logger=logger, **kwargs)
     elif model_name.lower().startswith("gemma"):
         from models.google import GoogleModel
         model=GoogleModel(model_name=model_name, logger=logger, **kwargs)
diff --git a/src/models/ollama.py b/src/models/ollama.py
new file mode 100644
index 0000000..f3f52a4
--- /dev/null
+++ b/src/models/ollama.py
@@ -0,0 +1,66 @@
+import os
+import ollama
+from tqdm.contrib.concurrent import thread_map
+
+from src.models.llm import LLM
+from src.utils.mylogger import MyLogger
+
+_model_name_map = {
+    "ollama-qwen-coder": "qwen2.5-coder:latest",
+    "ollama-qwen": "qwen2.5:32b",
+    "ollama-llama3": "llama3.2:latest",
+    "ollama-deepseek-32b": "deepseek-r1:32b",
+    "ollama-deepseek-7b": "deepseek-r1:latest",
+}
+_OLLAMA_DEFAULT_OPTIONS = {
+    "temperature": 0,
+    "num_predict": 4096,
+    "stop": None,
+    "seed": 345,
+}
+
+
+class OllamaModel(LLM):
+    def __init__(self, model_name, logger: MyLogger, **kwargs):
+        super().__init__(model_name, logger, _model_name_map, **kwargs)
+        host = os.environ.get("OLLAMA_HOST")
+        if not host:
+            self.log.error("OLLAMA_HOST is not set; falling back to the default local server")
+        self.client = ollama.Client(host=host)
+        # TODO: https://github.com/ollama/ollama/issues/2415
+        # self.logprobs = None
+        # keep a per-instance copy so overrides don't mutate the shared defaults
+        self.options = dict(_OLLAMA_DEFAULT_OPTIONS)
+        self.options.update({k: kwargs[k] for k in _OLLAMA_DEFAULT_OPTIONS if k in kwargs})
+
+    def predict(self, prompt, batch_size=0, no_progress_bar=False):
+        if batch_size == 0:
+            return self._predict(prompt)
+        args = range(0, len(prompt))
+        responses = thread_map(
+            lambda x: self._predict(prompt[x]),
+            args,
+            max_workers=batch_size,
+            disable=no_progress_bar,
+        )
+        return responses
+
+    def _predict(self, main_prompt):
+        # assume the first message is the system prompt and the second is the user prompt
+        system_prompt = main_prompt[0]["content"]
+        user_prompt = main_prompt[1]["content"]
+        prompt = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        try:
+            response = self.client.chat(
+                model=self.model_id,
+                messages=prompt,
+                options=self.options,
+            )
+        except ollama.ResponseError as e:
+            self.log.error(f"Ollama response error: {e.error}")
+            return None
+
+        return response.message.content