Patched the "update-vector-store" and moved the whole function that creates vector stores from rm.py to utils.py #126

Merged
merged 11 commits on Aug 3, 2024
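This PR removes the `--update-vector-store` flag and the `VectorRM.update_vector_store` method: CSV ingestion now happens up front through `QdrantVectorStoreManager.create_or_update_vector_store` in `utils.py`, and `VectorRM` only attaches to an existing collection. Below is a minimal sketch of the new offline flow, pieced together from the call sites in this diff; the paths and values are placeholders, and the helper's full signature lives in `knowledge_storm/utils.py`, which is not shown here.

```python
from knowledge_storm.rm import VectorRM
from knowledge_storm.utils import QdrantVectorStoreManager

# 1. Build (or update) the Qdrant collection from a CSV corpus before constructing the retriever.
#    Keyword names mirror the call in run_storm_wiki_gpt_with_VectorRM.py; values are placeholders.
QdrantVectorStoreManager.create_or_update_vector_store(
    collection_name="my_documents",
    vector_db_mode="offline",
    vector_store_path="./vector_store",   # offline mode; online mode passes url=... and api_key=... instead
    file_path="./corpus.csv",             # CSV with content, title, url, and description columns
    content_column="content",
    title_column="title",
    url_column="url",
    desc_column="description",
    batch_size=64,
    device="cpu",
)

# 2. VectorRM no longer creates collections; it now requires an explicit embedding model and
#    raises if the collection is missing. "BAAI/bge-m3" was the previous default in rm.py.
rm = VectorRM(collection_name="my_documents", embedding_model="BAAI/bge-m3", device="cpu", k=3)
rm.init_offline_vector_db(vector_store_path="./vector_store")
```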
5 changes: 1 addition & 4 deletions examples/README.md
@@ -54,7 +54,6 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
--update-vector-store \
--csv-file-path $CSV_FILE_PATH \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
@@ -70,7 +69,6 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
--output-dir $OUTPUT_DIR \
--vector-db-mode online \
--online-vector-db-url $ONLINE_VECTOR_DB_URL \
--update-vector-store \
--csv-file-path $CSV_FILE_PATH \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
@@ -85,7 +83,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
- Run the following command under the root directory to downsample the dataset by filtering papers with the term `[cs.CV]` and get a CSV file that matches the format mentioned above.

```
python examples/helper/process_kaggle_arxiv_abstract_dataset --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
python examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
```
- Run the following command to run STORM grounding on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only uses the abstracts of arXiv papers.)

@@ -94,7 +92,6 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
--update-vector-store \
--csv-file-path $PATH_TO_THE_PROCESSED_CSV \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
44 changes: 28 additions & 16 deletions examples/run_storm_wiki_gpt_with_VectorRM.py
@@ -30,10 +30,11 @@
import sys
from argparse import ArgumentParser

sys.path.append('./')
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.rm import VectorRM
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.utils import load_api_key
from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager


def main(args):
@@ -83,6 +84,31 @@ def main(args):
max_thread_num=args.max_thread_num,
)

# Create / update the vector store with the documents in the csv file
if args.csv_file_path:
kwargs = {
'file_path': args.csv_file_path,
'content_column': 'content',
'title_column': 'title',
'url_column': 'url',
'desc_column': 'description',
'batch_size': args.embed_batch_size,
'vector_db_mode': args.vector_db_mode,
'collection_name': args.collection_name,
'device': args.device,
}
if args.vector_db_mode == 'offline':
QdrantVectorStoreManager.create_or_update_vector_store(
vector_store_path=args.offline_vector_db_dir,
**kwargs
)
elif args.vector_db_mode == 'online':
QdrantVectorStoreManager.create_or_update_vector_store(
url=args.online_vector_db_url,
api_key=os.getenv('QDRANT_API_KEY'),
**kwargs
)

# Setup VectorRM to retrieve information from your own data
rm = VectorRM(collection_name=args.collection_name, device=args.device, k=engine_args.search_top_k)

@@ -92,17 +118,6 @@ def main(args):
elif args.vector_db_mode == 'online':
rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY'))

# Update the vector store with the documents in the csv file
if args.update_vector_store:
rm.update_vector_store(
file_path=args.csv_file_path,
content_column='content',
title_column='title',
url_column='url',
desc_column='description',
batch_size=args.embed_batch_size
)

# Initialize the STORM Wiki Runner
runner = STORMWikiRunner(engine_args, engine_lm_configs, rm)

@@ -139,10 +154,7 @@ def main(args):
help='If use offline mode, please provide the directory to store the vector store.')
parser.add_argument('--online-vector-db-url', type=str,
help='If use online mode, please provide the url of the Qdrant server.')
parser.add_argument('--update-vector-store', action='store_true',
help='If True, update the vector store with the documents in the csv file; otherwise, '
'use the existing vector store.')
parser.add_argument('--csv-file-path', type=str,
parser.add_argument('--csv-file-path', type=str, default=None,
help='The path of the custom document corpus in CSV format. The CSV file should include '
'content, title, url, and description columns.')
parser.add_argument('--embed-batch-size', type=int, default=64,
128 changes: 15 additions & 113 deletions knowledge_storm/rm.py
@@ -5,11 +5,10 @@
import dspy
import pandas as pd
import requests
from langchain_core.documents import Document

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from tqdm import tqdm
from qdrant_client import QdrantClient

from .utils import WebPageHelper

@@ -178,38 +177,38 @@ class VectorRM(dspy.Retrieve):
"""

def __init__(self,
collection_name: str = "my_documents",
embedding_model: str = 'BAAI/bge-m3',
collection_name: str,
embedding_model: str,
device: str = "mps",
k: int = 3,
chunk_size: int = 500,
chunk_overlap: int = 100):
):
"""
Params:
collection_name: Name of the Qdrant collection.
embedding_model: Name of the Hugging Face embedding model.
device: Device to run the embeddings model on, can be "mps", "cuda", "cpu".
k: Number of top chunks to retrieve.
chunk_size: Size of each chunk if you need to build the vector store from documents.
chunk_overlap: Overlap between chunks if you need to build the vector store from documents.
"""
super().__init__(k=k)
self.usage = 0
# check if the collection is provided
if not collection_name:
raise ValueError("Please provide a collection name.")
# check if the embedding model is provided
if not embedding_model:
raise ValueError("Please provide an embedding model.")

model_kwargs = {"device": device}
encode_kwargs = {"normalize_embeddings": True}
self.model = HuggingFaceEmbeddings(
model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap

self.collection_name = collection_name
self.client = None
self.qdrant = None

def _check_create_collection(self):
def _check_collection(self):
"""
Check if the Qdrant collection exists and raise an error if it does not.
"""
@@ -223,17 +222,7 @@ def _check_create_collection(self):
embeddings=self.model,
)
else:
print(f"Collection {self.collection_name} does not exist. Creating the collection...")
# create the collection
self.client.create_collection(
collection_name=f"{self.collection_name}",
vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE),
)
self.qdrant = Qdrant(
client=self.client,
collection_name=self.collection_name,
embeddings=self.model,
)
raise ValueError(f"Collection {self.collection_name} does not exist. Please create the collection first.")

def init_online_vector_db(self, url: str, api_key: str):
"""
@@ -252,7 +241,7 @@ def init_online_vector_db(self, url: str, api_key: str):

try:
self.client = QdrantClient(url=url, api_key=api_key)
self._check_create_collection()
self._check_collection()
except Exception as e:
raise ValueError(f"Error occurs when connecting to the server: {e}")

@@ -268,97 +257,10 @@ def init_offline_vector_db(self, vector_store_path: str):

try:
self.client = QdrantClient(path=vector_store_path)
self._check_create_collection()
self._check_collection()
except Exception as e:
raise ValueError(f"Error occurs when loading the vector store: {e}")

def update_vector_store(
self,
file_path: str,
content_column: str,
title_column: str = "title",
url_column: str = "url",
desc_column: str = "description",
batch_size: int = 64
):
"""
Takes a CSV file where each row is a document and has columns for content, title, url, and description.
Then it converts all these documents in the content column to vectors and add them the Qdrant collection.

Args:
file_path (str): Path to the CSV file.
content_column (str): Name of the column containing the content.
title_column (str): Name of the column containing the title. Default is "title".
url_column (str): Name of the column containing the URL. Default is "url".
desc_column (str): Name of the column containing the description. Default is "description".
batch_size (int): Batch size for adding documents to the collection.
"""
if file_path is None:
raise ValueError("Please provide a file path.")
# check if the file is a csv file
if not file_path.endswith('.csv'):
raise ValueError(f"Not valid file format. Please provide a csv file.")
if content_column is None:
raise ValueError("Please provide the name of the content column.")
if url_column is None:
raise ValueError("Please provide the name of the url column.")

if self.qdrant is None:
raise ValueError("Qdrant client is not initialized.")

# read the csv file
df = pd.read_csv(file_path)
# check that content column exists and url column exists
if content_column not in df.columns:
raise ValueError(f"Content column {content_column} not found in the csv file.")
if url_column not in df.columns:
raise ValueError(f"URL column {url_column} not found in the csv file.")

documents = [
Document(
page_content=row[content_column],
metadata={
"title": row.get(title_column, ''),
"url": row[url_column],
"description": row.get(desc_column, ''),
}
)
for row in df.to_dict(orient='records')
]

# split the documents
from langchain_text_splitters import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
length_function=len,
add_start_index=True,
separators=[
"\n\n",
"\n",
".",
"\uff0e", # Fullwidth full stop
"\u3002", # Ideographic full stop
",",
"\uff0c", # Fullwidth comma
"\u3001", # Ideographic comma
" ",
"\u200B", # Zero-width space
"",
]
)
split_documents = text_splitter.split_documents(documents)

# update and save the vector store
num_batches = (len(split_documents) + batch_size - 1) // batch_size
for i in tqdm(range(num_batches)):
start_idx = i * batch_size
end_idx = min((i + 1) * batch_size, len(split_documents))
self.qdrant.add_documents(
documents=split_documents[start_idx:end_idx],
batch_size=batch_size,
)

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0