Support chunk text-overflow display #170

Merged: 6 commits merged on Aug 30, 2024
35 changes: 29 additions & 6 deletions src/pai_rag/app/web/rag_client.py
@@ -4,9 +4,9 @@
 import httpx
 import os
 import re
-import mimetypes
 import markdown
 import html
+import mimetypes
 from http import HTTPStatus
 from pai_rag.app.web.view_model import ViewModel
 from pai_rag.app.web.ui_constants import EMPTY_KNOWLEDGEBASE_MESSAGE
@@ -95,16 +95,39 @@ def _format_rag_response(
             response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
             return response
         elif is_finished:
+            content_list = []
             for i, doc in enumerate(docs):
                 filename = doc["metadata"].get("file_name", None)
                 file_url = doc["metadata"].get("file_url", None)
-                if filename:
+                media_url = doc.get("metadata", {}).get("image_url", None)
+                if media_url and doc["text"] == "":
+                    formatted_image_name = re.sub(
+                        "^[0-9a-z]{32}_", "", "/".join(media_url.split("/")[-2:])
+                    )
+                    content = f"""
+                    <span>
+                        <a href="{media_url}"> [{i+1}]: {formatted_image_name} </a> Score:{doc["score"]}
+                    </span>
+                    <br>
+                    """
+                elif filename:
                     formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
+                    html_content = html.escape(
+                        re.sub(r"<.*?>", "", doc["text"])
+                    ).replace("\n", " ")
                     if file_url:
-                        formatted_file_name = f"""[{formatted_file_name}]({file_url})"""
-                    referenced_docs += (
-                        f'[{i+1}]: {formatted_file_name} Score:{doc["score"]}\n'
-                    )
+                        formatted_file_name = (
+                            f'<a href="{file_url}"> {formatted_file_name} </a>'
+                        )
+                    content = f"""
+                    <span class="text" title="{html_content}">
+                        [{i+1}]: {formatted_file_name} Score:{doc["score"]}
+                        <span style='color: blue; font-size: 12px; background-color: #FFCCCB'> ( {html_content[:40]}... ) </span>
+                    </span>
+                    <br>
+                    """
+                content_list.append(content)
+            referenced_docs = "".join(content_list)

         formatted_answer = ""
         if session_id:
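For context, a minimal standalone sketch of the reference-formatting idea above: the full (tag-stripped, escaped) chunk text goes into the span's title attribute so it appears as a hover tooltip, while only a short preview is rendered inline. The function name and parameters here are illustrative, not part of the PR.

```python
import html
import re


def format_reference(index: int, name: str, score: float, text: str, url: str | None = None) -> str:
    """Build one reference entry with the full chunk text in a tooltip and a short inline preview."""
    # Strip any HTML tags from the chunk and escape it so it is safe inside an attribute.
    safe_text = html.escape(re.sub(r"<.*?>", "", text)).replace("\n", " ")
    link = f'<a href="{url}"> {name} </a>' if url else name
    return (
        f'<span class="text" title="{safe_text}">'
        f"[{index}]: {link} Score:{score}"
        f"<span> ( {safe_text[:40]}... ) </span>"
        f"</span><br>"
    )


if __name__ == "__main__":
    print(format_reference(1, "report.pdf", 0.87, "A very long retrieved chunk of text ... " * 5))
```

Paired with the `.text` CSS rule added in chat_tab.py, the inline preview stays clamped while the complete chunk remains readable on hover.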
9 changes: 8 additions & 1 deletion src/pai_rag/app/web/tabs/chat_tab.py
@@ -383,7 +383,14 @@ def change_query_radio(query_type):
            )

        with gr.Column(scale=8):
-            chatbot = gr.Chatbot(height=500, elem_id="chatbot")
+            css = """
+            .text{
+                white-space: normal !important;
+                overflow:hidden;
+                text-overflow:ellipsis;
+                display: -webkit-box;
+            }"""
+            chatbot = gr.Chatbot(height=500, elem_id="chatbot", css=css)
            question = gr.Textbox(label="Enter your question.", elem_id="question")
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")
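For context, a self-contained sketch of how this kind of truncation CSS is commonly wired into a Gradio app. The PR attaches the stylesheet to the Chatbot component; the sketch below instead passes it to gr.Blocks (which accepts a css argument), and the line-clamp properties and demo layout are illustrative assumptions, not part of the PR.

```python
import gradio as gr

# Illustrative CSS: clamp each .text reference to a few lines and end with an ellipsis.
# -webkit-line-clamp / -webkit-box-orient are commonly paired with display: -webkit-box.
TRUNCATE_CSS = """
.text {
    white-space: normal !important;
    overflow: hidden;
    text-overflow: ellipsis;
    display: -webkit-box;
    -webkit-line-clamp: 2;
    -webkit-box-orient: vertical;
}
"""

with gr.Blocks(css=TRUNCATE_CSS) as demo:
    chatbot = gr.Chatbot(height=500, elem_id="chatbot")

if __name__ == "__main__":
    demo.launch()
```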
21 changes: 20 additions & 1 deletion src/pai_rag/data/rag_oss_dataloader.py
@@ -12,7 +12,7 @@
 from pai_rag.integrations.extractors.text_qa_extractor import TextQAExtractor
 from pai_rag.modules.nodeparser.node_parser import node_id_hash
 from pai_rag.data.open_dataset import MiraclOpenDataSet, DuRetrievalDataSet
-
+from llama_index.core.schema import BaseNode

 import logging
 import re
@@ -99,6 +99,24 @@ def _get_oss_files(self):
                 logger.error(f"Failed to load document {oss_obj.key}")
         return files

+    def _filter_text_nodes(self, nodes: List[BaseNode]):
+        filtered_nodes = []
+        text_seen = set()
+        text_seen.update(
+            node.text.strip()
+            for node in nodes
+            if isinstance(node, TextNode) and node.metadata.get("image_url") is not None
+        )
+        for node in nodes:
+            if (
+                isinstance(node, TextNode)
+                and node.metadata.get("image_url") is None
+                and node.text.strip() in text_seen
+            ):
+                continue
+            filtered_nodes.append(node)
+        return filtered_nodes
+
     def _get_nodes(
         self,
         file_path: str | List[str],
@@ -187,6 +205,7 @@ def _get_nodes(

         logger.info(f"[DataReader] Split into {len(nodes)} nodes.")

+        nodes = self._filter_text_nodes(nodes)
         # QA metadata extraction
         if enable_qa_extraction:
             qa_nodes = []
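The dedup idea behind _filter_text_nodes is: collect the text carried by image-bearing nodes, then drop plain text nodes whose text duplicates one of those, keeping the image nodes themselves. A minimal standalone sketch of that logic, using a simplified Chunk dataclass instead of llama_index's TextNode (names and fields are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Chunk:
    """Simplified stand-in for a parsed node; not llama_index's TextNode."""
    text: str
    image_url: Optional[str] = None


def drop_duplicate_text_chunks(chunks: List[Chunk]) -> List[Chunk]:
    # Text already attached to image-bearing chunks.
    image_texts = {c.text.strip() for c in chunks if c.image_url is not None}
    # Keep image chunks; drop plain-text chunks whose text duplicates an image chunk's text.
    return [
        c for c in chunks
        if not (c.image_url is None and c.text.strip() in image_texts)
    ]


if __name__ == "__main__":
    chunks = [
        Chunk("A caption", image_url="oss://bucket/img1.png"),
        Chunk("A caption"),          # duplicate of the image chunk's text -> dropped
        Chunk("Unrelated paragraph"),
    ]
    assert len(drop_duplicate_text_chunks(chunks)) == 2
```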