Skip to content

Commit

Permalink
fix: Remove experimental augmented classes (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab authored Nov 15, 2024
1 parent 6aee564 commit e878012
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 70 deletions.
24 changes: 0 additions & 24 deletions src/rago/augmented/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,24 @@ class AugmentedBase:
model_name: str = ''
db: Any
top_k: int = 0
temperature: float = 0.5
prompt_template: str = ''
result_separator = '\n'
output_max_length: int = 500

# default values to be overwritten by the derived classes
default_model_name: str = ''
default_top_k: int = 0
default_temperature: float = 0.5
default_prompt_template: str = (
'Retrieve {top_k} entries from the context that better answer the '
'following query:\n```\n{query}\n```\n\ncontext:\n```\n{context}\n```'
)
default_result_separator = '\n'
default_output_max_length: int = 500

def __init__(
self,
model_name: str = '',
api_key: str = '',
db: DBBase = FaissDB(),
top_k: int = 0,
temperature: float = 0.5,
prompt_template: str = '',
result_separator: str = '\n',
output_max_length: int = 500,
) -> None:
"""Initialize AugmentedBase."""
self.db = db
self.api_key = api_key

self.top_k = top_k or self.default_top_k
self.model_name = model_name or self.default_model_name
self.temperature = temperature or self.default_temperature
self.result_separator = (
result_separator or self.default_result_separator
)
self.prompt_template = prompt_template or self.default_prompt_template
self.output_max_length = (
output_max_length or self.default_output_max_length
)

self.model = None

self._validate()
Expand Down
1 change: 0 additions & 1 deletion src/rago/augmented/experimental/__init__.py

This file was deleted.

41 changes: 0 additions & 41 deletions src/rago/augmented/experimental/gemini.py

This file was deleted.

6 changes: 2 additions & 4 deletions tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from rago import Rago
from rago.augmented.experimental.gemini import GeminiAug
from rago.augmented import SentenceTransformerAug
from rago.generation import GeminiGen
from rago.retrieval import StringRet

Expand All @@ -27,9 +27,7 @@ def test_gemini_generation(animals_data: list[str], api_key: str) -> None:
# Instantiate Rago with the Gemini model
rag = Rago(
retrieval=StringRet(animals_data),
augmented=GeminiAug(
top_k=3
), # Update if using a specific augmentation class for Gemini
augmented=SentenceTransformerAug(top_k=3),
generation=GeminiGen(api_key=api_key, model_name='gemini-1.5-flash'),
)

Expand Down

0 comments on commit e878012

Please # to comment.