Skip to content

Research branch #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ venv/
*.pdf
*.mp3
*.sqlite
*.google-cookie
examples/graph_examples/ScrapeGraphAI_generated_graph
main.py
7 changes: 1 addition & 6 deletions examples/custom_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
from scrapegraphai.utils import convert_to_csv, convert_to_json

load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
Expand Down Expand Up @@ -68,8 +67,4 @@

# get the answer from the result
result = result.get("answer", "No answer found.")
print(result)

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
print(result)
33 changes: 33 additions & 0 deletions examples/search_graph_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Example of Search Graph
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SearchGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json

load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")

# Define the configuration for the graph
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
"temperature": 0,
},
}

# Create the SmartScraperGraph instance
smart_scraper_graph = SearchGraph(
prompt="List me all the regions of Italy.",
config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
5 changes: 0 additions & 5 deletions examples/smart_scraper_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json

load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
Expand All @@ -28,7 +27,3 @@

result = smart_scraper_graph.run()
print(result)

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
5 changes: 0 additions & 5 deletions examples/speech_graph_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SpeechGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json

load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
Expand Down Expand Up @@ -37,7 +36,3 @@

result = speech_graph.run()
print(result.get("answer", "No answer found"))

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ python-dotenv = "1.0.1"
tiktoken = {version = ">=0.5.2,<0.6.0"}
tqdm = "4.66.1"
graphviz = "0.20.1"
google = "3.0.0"

[tool.poetry.dev-dependencies]
pytest = "8.0.0"
Expand Down
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .base_graph import BaseGraph
from .smart_scraper_graph import SmartScraperGraph
from .speech_graph import SpeechGraph
from .search_graph import SearchGraph
43 changes: 43 additions & 0 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Module having abstract class for creating all the graphs
"""
from abc import ABC, abstractmethod
from typing import Optional

class AbstractGraph(ABC):
"""
Abstract class representing a generic graph-based tool.
"""

def __init__(self, prompt: str, config: dict, file_source: Optional[str] = "url"):
"""
Initializes the AbstractGraph with a prompt, file source, and configuration.
"""
self.prompt = prompt
self.file_source = file_source
self.input_key = "url" if file_source.startswith(
"http") else "local_dir"
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.graph = self._create_graph()

@abstractmethod
def _create_llm(self, llm_config: dict):
"""
Abstract method to create a language model instance.
"""
pass

@abstractmethod
def _create_graph(self):
"""
Abstract method to create a graph representation.
"""
pass

@abstractmethod
def run(self) -> str:
"""
Abstract method to execute the graph and return the result.
"""
pass
91 changes: 91 additions & 0 deletions scrapegraphai/graphs/search_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Module for making the search on the internet
"""
from ..models import OpenAI, Gemini
from .base_graph import BaseGraph
from ..nodes import (
SearchInternetNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
)
from .abstract_graph import AbstractGraph


class SearchGraph(AbstractGraph):
    """
    Graph-based tool that answers a user prompt by searching the internet,
    fetching the resulting page, and extracting the answer from it.
    """

    def _create_llm(self, llm_config: dict):
        """
        Creates an instance of the language model (OpenAI or Gemini) based on
        configuration.

        Args:
            llm_config: LLM settings; must include 'api_key' and 'model',
                may override the 'temperature' and 'streaming' defaults.

        Raises:
            ValueError: If 'api_key' or 'model' is missing, or if the model
                name is neither an OpenAI ("gpt-") nor a Gemini model.
        """
        llm_defaults = {
            "temperature": 0,
            "streaming": True
        }
        # Caller-supplied values take precedence over the defaults.
        llm_params = {**llm_defaults, **llm_config}
        if "api_key" not in llm_params:
            raise ValueError("LLM configuration must include an 'api_key'.")
        # Validate the model key up front so a missing entry raises a clear
        # error instead of a bare KeyError on the lookup below.
        if "model" not in llm_params:
            raise ValueError("LLM configuration must include a 'model'.")
        if "gpt-" in llm_params["model"]:
            return OpenAI(llm_params)
        elif "gemini" in llm_params["model"]:
            return Gemini(llm_params)
        else:
            raise ValueError("Model not supported")

    def _create_graph(self):
        """
        Creates the graph of nodes representing the workflow:
        search the internet -> fetch the page -> parse it -> retrieve
        relevant chunks -> generate the answer.
        """
        search_internet_node = SearchInternetNode(
            input="user_prompt",
            output=["url"],
            model_config={"llm_model": self.llm_model}
        )
        fetch_node = FetchNode(
            input="url | local_dir",
            output=["doc"],
        )
        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
        )
        rag_node = RAGNode(
            input="user_prompt & (parsed_doc | doc)",
            output=["relevant_chunks"],
            model_config={"llm_model": self.llm_model},
        )
        generate_answer_node = GenerateAnswerNode(
            input="user_prompt & (relevant_chunks | parsed_doc | doc)",
            output=["answer"],
            model_config={"llm_model": self.llm_model},
        )

        return BaseGraph(
            nodes={
                search_internet_node,
                fetch_node,
                parse_node,
                rag_node,
                generate_answer_node,
            },
            edges={
                (search_internet_node, fetch_node),
                (fetch_node, parse_node),
                (parse_node, rag_node),
                (rag_node, generate_answer_node)
            },
            entry_point=search_internet_node
        )

    def run(self) -> str:
        """
        Executes the search-and-scrape process.

        Returns:
            The answer extracted from the fetched page, or
            "No answer found." when the graph produced none.
        """
        inputs = {"user_prompt": self.prompt}
        final_state = self.graph.execute(inputs)

        return final_state.get("answer", "No answer found.")
54 changes: 4 additions & 50 deletions scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,64 +9,26 @@
RAGNode,
GenerateAnswerNode
)
from .abstract_graph import AbstractGraph


class SmartScraperGraph:
class SmartScraperGraph(AbstractGraph):
"""
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
information from web pages using a natural language model to interpret and answer prompts.

Attributes:
prompt (str): The user's natural language prompt for the information to be extracted.
url (str): The URL of the web page to scrape.
llm_config (dict): Configuration parameters for the language model, with
'api_key' being mandatory.
llm (ChatOpenAI): An instance of the ChatOpenAI class configured with llm_config.
graph (BaseGraph): An instance of the BaseGraph class representing the scraping workflow.

Methods:
run(): Executes the web scraping process and returns the answer to the prompt.

Args:
prompt (str): The user's natural language prompt for the information to be extracted.
url (str): The URL of the web page to scrape.
llm_config (dict): A dictionary containing configuration options for the language model.
Must include 'api_key', may also specify 'model_name',
'temperature', and 'streaming'.
"""

def __init__(self, prompt: str, file_source: str, config: dict):
"""
Initializes the SmartScraper with a prompt, URL, and language model configuration.
"""
self.prompt = prompt
self.file_source = file_source
self.input_key = "url" if file_source.startswith(
"http") else "local_dir"
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.graph = self._create_graph()

def _create_llm(self, llm_config: dict):
"""
Creates an instance of the ChatOpenAI class with the provided language model configuration.

Returns:
ChatOpenAI: An instance of the ChatOpenAI class.

Raises:
ValueError: If 'api_key' is not provided in llm_config.
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
llm_defaults = {
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **llm_config}
# Ensure the api_key is set, raise an error if it's not
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
# select the model based on the model name
if "gpt-" in llm_params["model"]:
return OpenAI(llm_params)
elif "gemini" in llm_params["model"]:
Expand All @@ -76,11 +38,7 @@ def _create_llm(self, llm_config: dict):
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping.

Returns:
BaseGraph: An instance of the BaseGraph class.
"""
# define the nodes for the graph
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
Expand Down Expand Up @@ -117,12 +75,8 @@ def _create_graph(self):

def run(self) -> str:
"""
Executes the scraping process by running the graph and returns the extracted information.

Returns:
str: The answer extracted from the web page, corresponding to the given prompt.
Executes the web scraping process and returns the answer to the prompt.
"""

inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
final_state = self.graph.execute(inputs)

Expand Down
Loading