Skip to content

Commit

Permalink
Replace mounted file_path with actual OSS path (#347)
Browse files Browse the repository at this point in the history
* Replace mounted file_path with actual OSS path

* Fix reviews

* Support multiple output documents and fix dataset input subdirectory
  • Loading branch information
wwxxzz authored Jan 21, 2025
1 parent 17d4435 commit e96afe6
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 19 deletions.
3 changes: 2 additions & 1 deletion src/pai_rag/tools/data_process/dataset/file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def _run_single_op(self, ops, op_name):
run_tasks = []
for i, batch_data in enumerate(self.data):
run_tasks.append(ops[i % num_actors].process.remote(batch_data))
self.data = ray.get(run_tasks)
results = ray.get(run_tasks)
self.data = [item for sub_results in results for item in sub_results]
except: # noqa: E722
logger.error(f"An error occurred during Op [{op_name}].")
import traceback
Expand Down
49 changes: 48 additions & 1 deletion src/pai_rag/tools/data_process/ops/parser_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import threading
from typing import List
from loguru import logger
from pathlib import Path
from urllib.parse import urlparse
from pai_rag.core.rag_module import resolve
from pai_rag.utils.oss_client import OssClient
from pai_rag.tools.data_process.ops.base_op import BaseOP, OPERATORS
Expand Down Expand Up @@ -55,6 +57,23 @@ def __init__(
reader_config=self.data_reader_config,
oss_store=self.oss_store,
)
self.mount_path = os.environ.get("INPUT_MOUNT_PATH", None).strip("/")
self.real_path = os.environ.get("OSS_SOURCE_PATH", None).strip("/")
if self.mount_path and self.real_path:
self.mount_path = Path(self.mount_path).resolve()
real_uri = urlparse(self.real_path)
if not real_uri.scheme:
logger.error(
f"Real path '{self.real_path}' must include a URI scheme (e.g., 'oss://')."
)
self.should_replace = False
else:
self.should_replace = True
else:
self.should_replace = False
logger.warning(
"File path won't be replaced to data source URI since either INPUT_MOUNT_PATH or OSS_SOURCE_PATH is not provided."
)
logger.info(
f"""ParserActor [PaiDataReader] init finished with following parameters:
concat_csv_rows: {concat_csv_rows}
Expand All @@ -63,14 +82,42 @@ def __init__(
sheet_column_filters: {sheet_column_filters}
oss_bucket: {oss_bucket}
oss_endpoint: {oss_endpoint}
path_should_replace: {self.should_replace}
"""
)

def replace_mount_with_real_path(self, documents):
    """Rebase each document's ``file_path`` metadata from the local mount
    point onto the real data-source URI.

    For every document whose ``file_path`` lies under ``self.mount_path``,
    the path is rewritten to ``self.real_path`` + relative part, and the
    original mounted path is preserved under the ``mount_path`` metadata
    key. Documents without ``file_path`` metadata, or whose path is not
    under the mount point, are left untouched. No-op when
    ``self.should_replace`` is False.
    """
    if not self.should_replace:
        return
    for doc in documents:
        if "file_path" not in doc.metadata:
            continue
        original_path = doc.metadata["file_path"]
        resolved_path = Path(original_path).resolve()
        try:
            relative_part = (
                resolved_path.relative_to(self.mount_path).as_posix().strip("/")
            )
            doc.metadata["file_path"] = f"{self.real_path}/{relative_part}"
            doc.metadata["mount_path"] = original_path
            logger.debug(
                f"Replacing original file_path: {original_path} --> {doc.metadata['file_path']}"
            )
        except ValueError:
            # file_path does not start with mount_path
            logger.debug(
                f"Path {original_path} does not start with mount path {self.mount_path}. No replacement done."
            )
        except Exception as e:
            logger.error(f"Error replacing path {original_path}: {e}")

def process(self, input_file):
    """Parse *input_file* into documents and serialize them to dicts.

    Loads documents through the configured data reader, rebases mounted
    file paths onto the real data-source URI when enabled, and returns
    the documents converted to plain dictionaries. Returns None when the
    reader yields no documents for the input.
    """
    worker = threading.current_thread()
    logger.info(f"当前线程的 ID: {worker.ident} 进程ID: {os.getpid()}")
    documents = self.data_reader.load_data(file_path_or_directory=input_file)
    if not documents:
        logger.info(f"No data found in the input file: {input_file}")
        return None
    self.replace_mount_with_real_path(documents)
    return convert_document_to_dict(documents)
10 changes: 4 additions & 6 deletions src/pai_rag/tools/data_process/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pai_rag.tools.data_process.utils.op_utils import (
load_op_names,
load_op,
replace_if_previous_op_subdirectory,
)


Expand Down Expand Up @@ -57,12 +58,9 @@ def run(self):
# self.cfg.dataset_path = self.cfg.export_path
else:
# TODO: support multiple operators
# dataset = RayDataset(
# os.path.join(
# self.cfg.dataset_path, get_previous_operation(op_name)
# ),
# self.cfg,
# )
self.cfg.dataset_path = replace_if_previous_op_subdirectory(
self.cfg.dataset_path, op_name
)
dataset = RayDataset(self.cfg.dataset_path, self.cfg)
ops = load_op(op_name, self.cfg.process)
logger.info(f"Processing op {op_name} ...")
Expand Down
27 changes: 16 additions & 11 deletions src/pai_rag/tools/data_process/utils/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,22 @@ def convert_node_to_dict(node):
}


def convert_document_to_dict(doc):
return {
"id": doc.id_,
"embedding": doc.embedding,
"metadata": doc.metadata,
"excluded_embed_metadata_keys": doc.excluded_embed_metadata_keys,
"excluded_llm_metadata_keys": doc.excluded_llm_metadata_keys,
"relationships": doc.relationships,
"text": doc.text,
"mimetype": doc.mimetype,
}
def convert_document_to_dict(documents):
    """Serialize a sequence of documents into plain dictionaries.

    Each document is flattened into a dict exposing its id, embedding,
    metadata, excluded-metadata key lists, relationships, text and
    mimetype, so the batch can be passed across Ray actors / stored
    without the original document class.
    """
    return [
        {
            "id": doc.id_,
            "embedding": doc.embedding,
            "metadata": doc.metadata,
            "excluded_embed_metadata_keys": doc.excluded_embed_metadata_keys,
            "excluded_llm_metadata_keys": doc.excluded_llm_metadata_keys,
            "relationships": doc.relationships,
            "text": doc.text,
            "mimetype": doc.mimetype,
        }
        for doc in documents
    ]


def convert_list_to_documents(doc_list):
Expand Down
14 changes: 14 additions & 0 deletions src/pai_rag/tools/data_process/utils/op_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from loguru import logger
from pathlib import Path
from pai_rag.tools.data_process.ops.base_op import OPERATORS
from pai_rag.tools.data_process.utils.mm_utils import size_to_bytes
from pai_rag.tools.data_process.utils.cuda_utils import get_num_gpus, calculate_np
Expand Down Expand Up @@ -58,3 +60,15 @@ def load_op_names(process_list):
op_name, _ = list(process.items())[0]
op_names.append(op_name)
return op_names


def replace_if_previous_op_subdirectory(folder_path, op_name):
    """Return the dataset path for *op_name*, descending into the previous
    operator's output subdirectory when it exists.

    Args:
        folder_path: Root dataset directory to inspect.
        op_name: Name of the operator about to run.

    Returns:
        ``folder_path/<previous_op>`` if that first-level subdirectory
        exists, otherwise ``folder_path`` unchanged.
    """
    previous_op_name = get_previous_operation(op_name)

    # Probe the one known first-level subdirectory directly instead of
    # scanning the whole directory listing; this is O(1) and also avoids
    # raising FileNotFoundError when folder_path does not exist yet.
    if (Path(folder_path) / previous_op_name).is_dir():
        return os.path.join(folder_path, previous_op_name)

    return folder_path

0 comments on commit e96afe6

Please sign in to comment.