diff --git a/src/pai_rag/tools/data_process/dataset/file_dataset.py b/src/pai_rag/tools/data_process/dataset/file_dataset.py index b90d44d9..cbec386b 100644 --- a/src/pai_rag/tools/data_process/dataset/file_dataset.py +++ b/src/pai_rag/tools/data_process/dataset/file_dataset.py @@ -25,7 +25,8 @@ def _run_single_op(self, ops, op_name): run_tasks = [] for i, batch_data in enumerate(self.data): run_tasks.append(ops[i % num_actors].process.remote(batch_data)) - self.data = ray.get(run_tasks) + results = ray.get(run_tasks) + self.data = [item for sub_results in results for item in sub_results] except: # noqa: E722 logger.error(f"An error occurred during Op [{op_name}].") import traceback diff --git a/src/pai_rag/tools/data_process/ops/parser_op.py b/src/pai_rag/tools/data_process/ops/parser_op.py index 1b9865a9..1bbe47b9 100644 --- a/src/pai_rag/tools/data_process/ops/parser_op.py +++ b/src/pai_rag/tools/data_process/ops/parser_op.py @@ -3,6 +3,8 @@ import threading from typing import List from loguru import logger +from pathlib import Path +from urllib.parse import urlparse from pai_rag.core.rag_module import resolve from pai_rag.utils.oss_client import OssClient from pai_rag.tools.data_process.ops.base_op import BaseOP, OPERATORS @@ -55,6 +57,23 @@ def __init__( reader_config=self.data_reader_config, oss_store=self.oss_store, ) + self.mount_path = os.environ.get("INPUT_MOUNT_PATH", None).strip("/") + self.real_path = os.environ.get("OSS_SOURCE_PATH", None).strip("/") + if self.mount_path and self.real_path: + self.mount_path = Path(self.mount_path).resolve() + real_uri = urlparse(self.real_path) + if not real_uri.scheme: + logger.error( + f"Real path '{self.real_path}' must include a URI scheme (e.g., 'oss://')." + ) + self.should_replace = False + else: + self.should_replace = True + else: + self.should_replace = False + logger.warning( + "File path won't be replaced to data source URI since either INPUT_MOUNT_PATH or OSS_SOURCE_PATH is not provided." 
def replace_mount_with_real_path(self, documents):
    """Rewrite each document's ``file_path`` metadata from its local mount
    location to the original data-source URI (``{real_path}/{relative}``).

    The original mounted path is kept under ``metadata["mount_path"]`` so
    downstream consumers can still find the local file. Documents without a
    ``file_path`` entry, or whose path is not under ``self.mount_path``, are
    left untouched. No-op when ``self.should_replace`` is falsy.

    Args:
        documents: iterable of document objects with a ``metadata`` dict.
    """
    # NOTE(review): __init__ calls .strip("/") on os.environ.get(..., None),
    # which raises AttributeError when the variable is unset — fix upstream.
    if not self.should_replace:
        return
    for document in documents:
        file_path = document.metadata.get("file_path")
        if file_path is None:
            continue
        try:
            # relative_to() raises ValueError when the resolved path is not
            # under mount_path; its as_posix() form never carries slashes at
            # either end, so no extra strip("/") is needed.
            relative_path_str = (
                Path(file_path).resolve().relative_to(self.mount_path).as_posix()
            )
            document.metadata["file_path"] = f"{self.real_path}/{relative_path_str}"
            document.metadata["mount_path"] = file_path
            logger.debug(
                f"Replacing original file_path: {file_path} --> {document.metadata['file_path']}"
            )
        except ValueError:
            # file_path does not start with mount_path
            logger.debug(
                f"Path {file_path} does not start with mount path {self.mount_path}. No replacement done."
            )
        except Exception as e:
            logger.error(f"Error replacing path {file_path}: {e}")
def convert_document_to_dict(documents):
    """Serialize documents into plain dicts (e.g. for Ray object transfer).

    Args:
        documents: iterable of Document-like objects exposing ``id_``,
            ``embedding``, ``metadata``, ``excluded_embed_metadata_keys``,
            ``excluded_llm_metadata_keys``, ``relationships``, ``text``
            and ``mimetype``.

    Returns:
        list[dict]: one dict per document with those fields; empty input
        yields an empty list.
    """
    return [
        {
            "id": document.id_,
            "embedding": document.embedding,
            "metadata": document.metadata,
            "excluded_embed_metadata_keys": document.excluded_embed_metadata_keys,
            "excluded_llm_metadata_keys": document.excluded_llm_metadata_keys,
            "relationships": document.relationships,
            "text": document.text,
            "mimetype": document.mimetype,
        }
        for document in documents
    ]


def replace_if_previous_op_subdirectory(folder_path, op_name):
    """Return ``folder_path/<previous_op>`` when the previous operation's
    output subdirectory exists directly under *folder_path*; otherwise
    return *folder_path* unchanged.

    Args:
        folder_path: dataset root directory to inspect.
        op_name: current operation name; its predecessor is looked up via
            ``get_previous_operation``.
    """
    previous_op_name = get_previous_operation(op_name)
    # Direct existence check instead of scanning every entry with iterdir();
    # also avoids raising FileNotFoundError when folder_path does not exist.
    if (Path(folder_path) / previous_op_name).is_dir():
        return os.path.join(folder_path, previous_op_name)
    return folder_path