Skip to content

Commit

Permalink
Merge pull request #59 from willwoodward/fixing-classes
Browse files Browse the repository at this point in the history
Fixing classes
  • Loading branch information
willwoodward authored Oct 30, 2024
2 parents 0049a05 + 4f5682b commit 530f48b
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 42 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: Run Unit Tests
on:
pull_request:
branches: [ 'main' ]
workflow_dispatch:

jobs:
test:
Expand All @@ -18,11 +19,12 @@ jobs:
python-version: '3.12'

- name: Install dependencies
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
python -m pip install --upgrade pip
pip install -e .
pip install pytest
pip install python-dotenv
woodwork init --all
- name: Run tests
run: pytest -s tests/
pytest -s tests/
41 changes: 39 additions & 2 deletions tests/input_interface_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,58 @@
import pytest
import os
from dotenv import load_dotenv

from woodwork.components.input_interface import input_interface
from woodwork.helper_functions import import_all_classes
import_all_classes('woodwork.components')
load_dotenv()

def get_all_subclasses(cls):
    """Return every transitive subclass of *cls*, depth-first, without duplicates.

    Bug fix: the original extended the very list it was iterating
    (``for subclass in subclasses: subclasses.extend(...)``), so classes
    appended during the loop were visited again and recursed into a second
    time — hierarchies three or more levels deep produced duplicate entries.
    Iterating over a snapshot of the direct subclasses and recursing avoids
    the re-visit entirely.
    """
    subclasses = []
    for subclass in cls.__subclasses__():
        subclasses.append(subclass)
        # Recurse into each direct child exactly once.
        subclasses.extend(get_all_subclasses(subclass))
    return subclasses

input_implementors = get_all_subclasses(input_interface)
def get_leaf_subclasses(cls):
    """Collect the subclasses of *cls* that have no subclasses of their own.

    Walks the inheritance tree depth-first; a child with no descendants is a
    leaf and is returned itself, otherwise its own leaves are gathered.
    """
    leaves = []
    for child in cls.__subclasses__():
        deeper = get_leaf_subclasses(child)
        # No grandchildren -> the child itself is a leaf node.
        leaves.extend(deeper if deeper else [child])
    return leaves

# Factory function to initialise classes with default inputs
def create_instance(cls):
    """Instantiate *cls*, supplying known default constructor arguments.

    Classes listed in the table below (keyed by class name) are constructed
    with a default name and an API key read from the environment; any other
    class is constructed with no arguments.
    """
    # class name -> keyword arguments for the constructor
    defaults = {
        "openai": {
            "name": "openai_example",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "hugging_face": {
            "name": "hugging_face-example",
            "api_key": os.getenv("HF_API_TOKEN"),
        },
    }

    kwargs = defaults.get(cls.__name__)
    return cls(**kwargs) if kwargs is not None else cls()

input_implementors = get_leaf_subclasses(input_interface)
print("Collected subclasses of input_interface:", input_implementors)

@pytest.mark.parametrize("input_implementor", input_implementors)
def test_input_returns(input_implementor):
input_instance = input_implementor()
input_instance = create_instance(input_implementor)

try:
result = input_instance.input("Some input")
Expand Down
20 changes: 11 additions & 9 deletions woodwork/components/llms/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from woodwork.components.llms.llm import llm

class hugging_face(llm):
def __init__(self, name, config):
def __init__(self, name, api_key: str, model="mistralai/Mixtral-8x7B-Instruct-v0.1", **config):
super().__init__(name, **config)
print_debug(f"Establishing connection with model...")

llm = HuggingFaceEndpoint(
repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
self._llm_value = HuggingFaceEndpoint(
repo_id=model,
temperature=0.1,
model_kwargs={"max_length": 100},
huggingfacehub_api_token=config["api_key"]
huggingfacehub_api_token=api_key
)

retriever = None
if "knowledge_base" in config:
retriever = config["knowledge_base"].retriever

super().__init__(name, llm, retriever, config)
self._retriever = config.get("knowledge_base")
if self._retriever:
self._retriever = self._retriever.retriever

print_debug("Model initialised.")

@property
def _llm(self): return self._llm_value
32 changes: 18 additions & 14 deletions woodwork/components/llms/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from abc import ABC, abstractmethod

class llm(component, input_interface):
def __init__(self, name, llm, retriever, config):
# Each LLM will have a: LLM object, input_handler, retriever?
class llm(component, input_interface, ABC):
def __init__(self, name, **config):
super().__init__(name, "llm")

self.__llm = llm
self.__retriever = retriever

self._memory = None
if "memory" in config:
self._memory = config["memory"]
self._memory = config.get("memory")

@property
@abstractmethod
def _llm(self): pass

def input(self, input: str) -> str:
return self.input_handler(input)

def input_handler(self, query):
# If there is no retriever object, there is no connected Knowledge Base
if not self.__retriever:
if not self._retriever:
return self.question_answer(query)
else:
return self.context_answer(query)
Expand All @@ -49,8 +48,13 @@ def question_answer(self, query):
]
)

chain = prompt | self.__llm
response = chain.invoke({"input": query}).content
chain = prompt | self._llm
response = chain.invoke({"input": query})

try:
response = response.content
except:
pass

# Adding to memory
if self._memory:
Expand All @@ -75,8 +79,8 @@ def context_answer(self, query):
]
)

question_answer_chain = create_stuff_documents_chain(self.__llm, prompt)
chain = create_retrieval_chain(self.__retriever, question_answer_chain)
question_answer_chain = create_stuff_documents_chain(self._llm, prompt)
chain = create_retrieval_chain(self._retriever, question_answer_chain)

return chain.invoke({"input": query})['answer']

Expand Down
19 changes: 11 additions & 8 deletions woodwork/components/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@
from woodwork.components.llms.llm import llm

class openai(llm):
def __init__(self, name, config):
def __init__(self, name, api_key: str, model="gpt-4o-mini", **config):
print_debug(f"Establishing connection with model...")

llm = ChatOpenAI(
model=config["model"],
self._llm_value = ChatOpenAI(
model=model,
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=config["api_key"]
api_key=api_key
)

retriever = None
if "knowledge_base" in config:
retriever = config["knowledge_base"].retriever
self._retriever = config.get("knowledge_base")
if self._retriever:
self._retriever = self._retriever.retriever

super().__init__(name, llm, retriever, config)
super().__init__(name, **config)

print_debug("Model initialised.")

@property
def _llm(self): return self._llm_value
4 changes: 2 additions & 2 deletions woodwork/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def create_object(command):
if type == "short_term": return short_term(command["variable"], command["config"])

if component == "llm":
if type == "hugging_face": return hugging_face(command["variable"], command["config"])
if type == "openai": return openai(command["variable"], command["config"])
if type == "hugging_face": return hugging_face(command["variable"], **command["config"])
if type == "openai": return openai(command["variable"], **command["config"])

if component == "input":
if type == "command_line": return command_line(command["variable"], command["config"])
Expand Down
10 changes: 7 additions & 3 deletions woodwork/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def init(options={"isolated": False}):
if os.path.exists(temp_requirements_file):
os.remove(temp_requirements_file)


print("Initialization complete.")

def get_subdirectories(path: str) -> list[str]:
Expand All @@ -138,6 +137,11 @@ def get_subdirectories(path: str) -> list[str]:

def install_all():
print("Installing all dependencies...")

setup_virtual_env({"isolated": True})

# Change this to work with windows
activate_script = '.woodwork/env/bin/activate'

# Access the requirements directory as a package resource
requirements_dir = pkg_resources.files('woodwork')/'requirements'
Expand Down Expand Up @@ -167,7 +171,7 @@ def install_all():
f.write(f"{requirement}\n")

try:
subprocess.check_call([f"pip install -r {temp_requirements_file}"], shell=True)
subprocess.check_call([f". {activate_script} && pip install -r {temp_requirements_file}"], shell=True)
print(f"Installed all combined dependencies.")
except subprocess.CalledProcessError:
sys.exit(1)
Expand All @@ -176,5 +180,5 @@ def install_all():
if os.path.exists(temp_requirements_file):
os.remove(temp_requirements_file)


activate_virtual_environment()
print("Initialization complete.")
7 changes: 6 additions & 1 deletion woodwork/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ def import_all_classes(package_name: str) -> bool:
if file.endswith(".py") and file != "__init__.py":
# Derive the full module path
relative_path = os.path.relpath(root, package_path)
print("RELPATH =", relative_path)
module_name = os.path.splitext(file)[0]
full_module_name = f"{package_name}.{relative_path.replace(os.path.sep, '.')}.{module_name}"

if relative_path == ".":
full_module_name = f"{package_name}.{module_name}"
else:
full_module_name = f"{package_name}.{relative_path.replace(os.path.sep, '.')}.{module_name}"

# Import the module
try:
Expand Down

0 comments on commit 530f48b

Please sign in to comment.