From e8780125db3c5966493621eb3e5f2a7713ccc419 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Fri, 15 Nov 2024 12:54:39 -0400 Subject: [PATCH] fix: Remove experimental augmented classes (#21) --- src/rago/augmented/base.py | 24 ------------ src/rago/augmented/experimental/__init__.py | 1 - src/rago/augmented/experimental/gemini.py | 41 --------------------- tests/test_gemini.py | 6 +-- 4 files changed, 2 insertions(+), 70 deletions(-) delete mode 100644 src/rago/augmented/experimental/__init__.py delete mode 100644 src/rago/augmented/experimental/gemini.py diff --git a/src/rago/augmented/base.py b/src/rago/augmented/base.py index d8e2e67..f98a31e 100644 --- a/src/rago/augmented/base.py +++ b/src/rago/augmented/base.py @@ -25,21 +25,10 @@ class AugmentedBase: model_name: str = '' db: Any top_k: int = 0 - temperature: float = 0.5 - prompt_template: str = '' - result_separator = '\n' - output_max_length: int = 500 # default values to be overwritten by the derived classes default_model_name: str = '' default_top_k: int = 0 - default_temperature: float = 0.5 - default_prompt_template: str = ( - 'Retrieve {top_k} entries from the context that better answer the ' - 'following query:\n```\n{query}\n```\n\ncontext:\n```\n{context}\n```' - ) - default_result_separator = '\n' - default_output_max_length: int = 500 def __init__( self, @@ -47,10 +36,6 @@ def __init__( api_key: str = '', db: DBBase = FaissDB(), top_k: int = 0, - temperature: float = 0.5, - prompt_template: str = '', - result_separator: str = '\n', - output_max_length: int = 500, ) -> None: """Initialize AugmentedBase.""" self.db = db @@ -58,15 +43,6 @@ def __init__( self.top_k = top_k or self.default_top_k self.model_name = model_name or self.default_model_name - self.temperature = temperature or self.default_temperature - self.result_separator = ( - result_separator or self.default_result_separator - ) - self.prompt_template = prompt_template or self.default_prompt_template - self.output_max_length = ( - output_max_length or self.default_output_max_length - ) - self.model = None self._validate() diff --git a/src/rago/augmented/experimental/__init__.py b/src/rago/augmented/experimental/__init__.py deleted file mode 100644 index e6a50c9..0000000 --- a/src/rago/augmented/experimental/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Experimental augmented classes.""" diff --git a/src/rago/augmented/experimental/gemini.py b/src/rago/augmented/experimental/gemini.py deleted file mode 100644 index d9de567..0000000 --- a/src/rago/augmented/experimental/gemini.py +++ /dev/null @@ -1,41 +0,0 @@ -"""GeminiAug class for query augmentation using Google's Gemini Model.""" - -from __future__ import annotations - -import google.generativeai as genai - -from typeguard import typechecked - -from rago.augmented.base import AugmentedBase - - -@typechecked -class GeminiAug(AugmentedBase): - """GeminiAug class for query augmentation using Gemini API.""" - - default_model_name: str = 'gemini-1.5-flash' - default_top_k: int = 1 - - def _setup(self) -> None: - """Set up the object with the initial parameters.""" - genai.configure(api_key=self.api_key) - - def search( - self, query: str, documents: list[str], top_k: int = 0 - ) -> list[str]: - """Augment the query by expanding or rephrasing it using Gemini.""" - top_k = top_k or self.top_k - prompt = self.prompt_template.format( - query=query, context=' '.join(documents), top_k=top_k - ) - - response = genai.GenerativeModel(self.model_name).generate_content( - prompt - ) - - augmented_query = str( - response.text.strip() - if hasattr(response, 'text') - else response[0].text.strip() - ) - return augmented_query.split(self.result_separator)[:top_k] diff --git a/tests/test_gemini.py b/tests/test_gemini.py index 91d6046..1d41e91 100644 --- a/tests/test_gemini.py +++ b/tests/test_gemini.py @@ -5,7 +5,7 @@ import pytest from rago import Rago -from rago.augmented.experimental.gemini import GeminiAug +from rago.augmented import SentenceTransformerAug from rago.generation import GeminiGen from rago.retrieval import StringRet @@ -27,9 +27,7 @@ def test_gemini_generation(animals_data: list[str], api_key: str) -> None: # Instantiate Rago with the Gemini model rag = Rago( retrieval=StringRet(animals_data), - augmented=GeminiAug( - top_k=3 - ), # Update if using a specific augmentation class for Gemini + augmented=SentenceTransformerAug(top_k=3), generation=GeminiGen(api_key=api_key, model_name='gemini-1.5-flash'), )