Skip to content

Commit

Permalink
Merge pull request #134 from kevindragon/main
Browse files Browse the repository at this point in the history
[New RM] support brave search
  • Loading branch information
shaoyijia authored Aug 9, 2024
2 parents 11a1707 + 457ed9e commit 821a324
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 16 deletions.
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -65,6 +65,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -89,7 +91,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import DeepSeekModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -90,6 +90,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)
else:
raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.")

Expand Down Expand Up @@ -123,7 +125,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True,
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'], required=True,
help='The search engine API to use for retrieving information.')
parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat',
help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.')
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import GoogleModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -67,6 +67,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -91,7 +93,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -77,6 +77,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -101,7 +103,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import VLLMClient
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.rm import YouRM, BingSearch, BraveRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -63,6 +63,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand Down Expand Up @@ -147,7 +149,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 4 additions & 2 deletions examples/run_storm_wiki_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

sys.path.append('./src')
from lm import OllamaClient
from rm import YouRM, BingSearch
from rm import YouRM, BingSearch, BraveRM
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from utils import load_api_key

Expand Down Expand Up @@ -65,6 +65,8 @@ def main(args):
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=engine_args.search_top_k)

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand Down Expand Up @@ -151,7 +153,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
77 changes: 73 additions & 4 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st

return collected_results


class SerperRM(dspy.Retrieve):
"""Retrieve information from custom queries using Serper.dev."""

def __init__(self, serper_search_api_key=None, query_params=None):
"""Args:
serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
Expand Down Expand Up @@ -373,8 +374,8 @@ def get_usage_and_reset(self):
def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
"""
Calls the API and searches for the query passed in.
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
Expand All @@ -395,7 +396,7 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
if query == 'Queries:':
continue
query_params = self.query_params

# All available parameters can be found in the playground: https://serper.dev/playground
# Sets the json value for query to be the query that is being parsed.
query_params['q'] = query
Expand Down Expand Up @@ -441,3 +442,71 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
continue

return collected_results


class BraveRM(dspy.Retrieve):
    """Retrieve web search results from the Brave Search API (api.search.brave.com)."""

    def __init__(self, brave_search_api_key=None, k=3, is_valid_source: Callable = None):
        """
        Args:
            brave_search_api_key: Brave Search subscription token. When omitted, the
                BRAVE_API_KEY environment variable is used instead.
            k: Number of top results to keep for each query.
            is_valid_source: Optional function that takes a URL and returns a boolean;
                results whose URL it rejects are dropped. Defaults to accepting every URL.

        Raises:
            RuntimeError: If no API key is passed and BRAVE_API_KEY is unset or empty.
        """
        super().__init__(k=k)
        api_key = brave_search_api_key or os.environ.get("BRAVE_API_KEY")
        if not api_key:
            raise RuntimeError("You must supply brave_search_api_key or set environment variable BRAVE_API_KEY")
        self.brave_search_api_key = api_key
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        """Return the number of queries issued since the last reset and zero the counter.

        Returns:
            A dict of the form {'BraveRM': <query count>}.
        """
        usage = self.usage
        self.usage = 0

        return {'BraveRM': usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = None):
        """Search with api.search.brave.com for self.k top passages for query or queries

        A query that fails (network error, timeout, HTTP error status, malformed
        response) is logged and skipped, so the remaining queries still run.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.
                None (the default) excludes nothing.
        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)

        # A set gives O(1) membership tests; `or ()` also covers the None default.
        excluded_urls = set(exclude_urls or ())
        headers = {
            "Accept": "application/json",
            "Accept-Encoding": "gzip",
            "X-Subscription-Token": self.brave_search_api_key
        }

        collected_results = []
        for query in queries:
            try:
                # Pass the query through `params` so it is URL-encoded; interpolating it
                # into the URL breaks on characters such as '&', '#', '+' or spaces.
                response = requests.get(
                    "https://api.search.brave.com/res/v1/web/search",
                    params={"result_filter": "web", "q": query},
                    headers=headers,
                    timeout=30,  # without a timeout a stalled connection blocks forever
                )
                # Surface 4xx/5xx (bad key, rate limit) in the log instead of silently
                # treating the error payload as "no results".
                response.raise_for_status()
                results = response.json().get('web', {}).get('results', [])

                kept = 0
                for result in results:
                    if kept >= self.k:
                        break
                    url = result.get('url')
                    # Skip results that have no URL, are explicitly excluded, or are
                    # rejected by the caller-supplied source filter.
                    if not url or url in excluded_urls or not self.is_valid_source(url):
                        continue
                    collected_results.append(
                        {
                            'snippets': result.get('extra_snippets', []),
                            'title': result.get('title'),
                            'url': url,
                            'description': result.get('description'),
                        }
                    )
                    kept += 1
            except Exception as e:
                # Best-effort retrieval: one failing query must not abort the others.
                logging.error(f'Error occurs when searching query {query}: {e}')

        return collected_results

0 comments on commit 821a324

Please sign in to comment.