Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[New RM] support brave search #134

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -65,6 +65,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -89,7 +91,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import DeepSeekModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -90,6 +90,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)
else:
raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.")

Expand Down Expand Up @@ -123,7 +125,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True,
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'], required=True,
help='The search engine API to use for retrieving information.')
parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat',
help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.')
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import GoogleModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -67,6 +67,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -91,7 +93,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -77,6 +77,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -101,7 +103,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import VLLMClient
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -63,6 +63,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand Down Expand Up @@ -147,7 +149,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

sys.path.append('./src')
from lm import OllamaClient
from rm import YouRM, BingSearch
from rm import YouRM, BingSearch, BraveRM
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from utils import load_api_key

Expand Down Expand Up @@ -65,6 +65,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand Down Expand Up @@ -151,7 +153,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
77 changes: 73 additions & 4 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st

return collected_results


class SerperRM(dspy.Retrieve):
"""Retrieve information from custom queries using Serper.dev."""

def __init__(self, serper_search_api_key=None, query_params=None):
"""Args:
serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
Expand Down Expand Up @@ -373,8 +374,8 @@ def get_usage_and_reset(self):
def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
"""
Calls the API and searches for the query passed in.


Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
Expand All @@ -395,7 +396,7 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
if query == 'Queries:':
continue
query_params = self.query_params

# All available parameters can be found in the playground: https://serper.dev/playground
# Sets the json value for query to be the query that is being parsed.
query_params['q'] = query
Expand Down Expand Up @@ -441,3 +442,71 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
continue

return collected_results


class BraveRM(dspy.Retrieve):
def __init__(self, brave_search_api_key=None, k=3, is_valid_source: Callable = None):
super().__init__(k=k)
if not brave_search_api_key and not os.environ.get("BRAVE_API_KEY"):
raise RuntimeError("You must supply brave_search_api_key or set environment variable BRAVE_API_KEY")
elif brave_search_api_key:
self.brave_search_api_key = brave_search_api_key
else:
self.brave_search_api_key = os.environ["BRAVE_API_KEY"]
self.usage = 0

# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0

return {'BraveRM': usage}

def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
"""Search with api.search.brave.com for self.k top passages for query or queries

Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of urls to exclude from the search results.

Returns:
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
collected_results = []
for query in queries:
try:
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.brave_search_api_key
}
response = requests.get(
f"https://api.search.brave.com/res/v1/web/search?result_filter=web&q={query}",
headers=headers,
).json()
results = response.get('web', {}).get('results', [])

for result in results:
collected_results.append(
{
'snippets': result.get('extra_snippets', []),
'title': result.get('title'),
'url': result.get('url'),
'description': result.get('description'),
}
)
except Exception as e:
logging.error(f'Error occurs when searching query {query}: {e}')

return collected_results