
Commit: improve deduping issue
prasmussen15 committed Aug 23, 2024
1 parent 9cc9883 commit b25e1e6
Showing 4 changed files with 40 additions and 37 deletions.
34 changes: 16 additions & 18 deletions core/prompts/dedupe_edges.py
@@ -25,29 +25,27 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
- Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
+ Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
- Existing Edges:
+ Existing Facts:
{json.dumps(context['existing_edges'], indent=2)}
- New Edges:
+ New Facts:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
- 1. start with the list of edges from New Edges
- 2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
- edge in the list
- 3. Respond with the resulting list of edges
+ If any fact in New Facts is a duplicate of a fact in Existing Facts,
+ do not return it in the list of unique facts.
Guidelines:
- 1. Use both the name and fact of edges to determine if they are duplicates,
- duplicate edges may have different names
+ 1. The facts do not have to be completely identical to be duplicates,
+ they just need to have similar factual content
Respond with a JSON object in the following format:
{{
"new_edges": [
"unique_facts": [
{{
"fact": "one sentence description of the fact"
"uuid": "unique identifier of the fact"
}}
]
}}
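The revised v1 prompt turns deduplication into an identifier-selection task: the model no longer rewrites edges, it just returns the uuids of the genuinely new facts. A minimal sketch of the implied contract, with illustrative uuids and facts that are not from any real run:

# Hedged sketch of the JSON contract the revised v1 prompt implies.
context = {
    'existing_edges': [
        {'uuid': 'e-1', 'name': 'WORKS_AT', 'fact': 'Alice works at Acme.'},
    ],
    'extracted_edges': [
        # near-duplicate of e-1, despite the different relation name
        {'uuid': 'n-1', 'name': 'EMPLOYED_BY', 'fact': 'Alice is employed by Acme.'},
        {'uuid': 'n-2', 'name': 'LIVES_IN', 'fact': 'Alice lives in Berlin.'},
    ],
}

# A well-behaved model should drop the near-duplicate and answer with only
# the unique new fact, keyed by uuid:
expected_response = {'unique_facts': [{'uuid': 'n-2'}]}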
@@ -107,24 +105,24 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
- Given the following context, find all of the duplicates in a list of edges:
+ Given the following context, find all of the duplicates in a list of facts:
- Edges:
+ Facts:
{json.dumps(context['edges'], indent=2)}
Task:
- If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
+ If any fact in Facts is a duplicate of another fact, return a new fact with one of their uuids.
Guidelines:
- 1. Use both the name and fact of edges to determine if they are duplicates,
- edges with the same name may not be duplicates
- 2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
+ 1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
+ 2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
facts should be in the response
Respond with a JSON object in the following format:
{{
"unique_edges": [
"unique_facts": [
{{
"uuid": "unique identifier of the fact"

[GitHub Actions / ruff] core/prompts/dedupe_edges.py:125:1: E101 Indentation contains mixed spaces and tabs
"fact": "fact of a unique edge",
}}
]
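Guideline 2 above is the subtle part of the edge_list prompt: an n-way duplicate group should collapse to a single surviving uuid. A small sketch of the expected behavior, with invented data:

# Hedged sketch of the edge_list contract; uuids and facts are made up.
context = {
    'edges': [
        {'uuid': 'f-1', 'fact': 'Bob founded Initech.'},
        {'uuid': 'f-2', 'fact': 'Initech was founded by Bob.'},       # duplicate of f-1
        {'uuid': 'f-3', 'fact': 'Bob founded the company Initech.'},  # duplicate of f-1
        {'uuid': 'f-4', 'fact': 'Initech is based in Austin.'},
    ],
}

# Per guideline 2, only one uuid from the three-way duplicate group survives:
expected_response = {
    'unique_facts': [
        {'uuid': 'f-1', 'fact': 'Bob founded Initech.'},
        {'uuid': 'f-4', 'fact': 'Initech is based in Austin.'},
    ]
}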
9 changes: 8 additions & 1 deletion core/utils/bulk_utils.py
@@ -1,6 +1,7 @@
import asyncio
import typing
from datetime import datetime
+ from numpy import dot

from neo4j import AsyncDriver
from pydantic import BaseModel
@@ -23,7 +24,7 @@
extract_nodes,
)

- CHUNK_SIZE = 10
+ CHUNK_SIZE = 15

[GitHub Actions / ruff] core/utils/bulk_utils.py:1:1: I001 Import block is un-sorted or un-formatted


class BulkEpisode(BaseModel):
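The I001 failure above is triggered because the new "from numpy import dot" was appended to the stdlib group instead of being slotted into the third-party group. A sketch of an import order that Ruff's isort rules should accept, with the grouping inferred from the visible imports:

import asyncio
import typing
from datetime import datetime

from neo4j import AsyncDriver
from numpy import dot
from pydantic import BaseModel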
@@ -137,6 +138,9 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
+ anchor = nodes[0] if len(nodes) > 0 else None
+ nodes.sort(key=lambda node: dot(anchor.name_embedding, node.name_embedding))

[GitHub Actions / mypy] core/utils/bulk_utils.py:142: union-attr: Item "None" of "EntityNode | None" has no attribute "name_embedding"
[GitHub Actions / mypy] core/utils/bulk_utils.py:142: arg-type: Argument 1 to "dot" has incompatible type "list[float] | Any | None"; an array-like is expected
[GitHub Actions / mypy] core/utils/bulk_utils.py:142: arg-type: Argument 2 to "dot" has incompatible type "list[float] | None"; an array-like is expected

node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
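The mypy failures flag two real hazards in the new anchor sort: nodes may be empty (so anchor is None), and the embeddings themselves are Optional. One None-safe shape that should satisfy the checker, sketched against a stand-in EntityNode rather than the repository's model:

from dataclasses import dataclass

import numpy as np


@dataclass
class EntityNode:
    # Stand-in for the repository's model; only name_embedding matters here.
    name: str
    name_embedding: list[float] | None = None


def sort_by_anchor_similarity(nodes: list[EntityNode]) -> None:
    # Sort nodes in place by dot-product similarity to the first node's embedding.
    if not nodes or nodes[0].name_embedding is None:
        return  # no usable anchor; keep the original order
    anchor = np.asarray(nodes[0].name_embedding)
    nodes.sort(
        key=lambda n: float(np.dot(anchor, np.asarray(n.name_embedding)))
        if n.name_embedding is not None
        else float('-inf')  # nodes without embeddings sort to the front
    )

The point of the sort, presumably, is that chunking happens right after it: ordering by similarity to an anchor pulls near-duplicates into the same CHUNK_SIZE window, so each per-chunk dedupe_node_list call actually sees the candidates side by side.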
@@ -156,6 +160,9 @@ async def compress_nodes(


async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
+ anchor = edges[0] if len(edges) > 0 else None
+ edges.sort(key=lambda embedding: dot(anchor.fact_embedding, embedding.fact_embedding))

[GitHub Actions / mypy] core/utils/bulk_utils.py:164: union-attr: Item "None" of "EntityEdge | None" has no attribute "fact_embedding"
[GitHub Actions / mypy] core/utils/bulk_utils.py:164: arg-type: Argument 1 to "dot" has incompatible type "list[float] | Any | None"; an array-like is expected
[GitHub Actions / mypy] core/utils/bulk_utils.py:164: arg-type: Argument 2 to "dot" has incompatible type "list[float] | None"; an array-like is expected

edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]

results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
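compress_edges trips the identical three failures on fact_embedding. Since both call sites sort by similarity of an Optional embedding, a shared helper is one way to fix both at once; the attribute-getter parameter here is an illustrative assumption, not the repository's API:

from collections.abc import Callable
from typing import TypeVar

import numpy as np

T = TypeVar('T')


def sort_by_anchor(items: list[T], embedding_of: Callable[[T], list[float] | None]) -> None:
    # Sort items in place by dot-product similarity to the first item's embedding.
    if not items or embedding_of(items[0]) is None:
        return  # no usable anchor; keep the original order
    anchor = np.asarray(embedding_of(items[0]))
    items.sort(
        key=lambda item: float(np.dot(anchor, np.asarray(embedding_of(item))))
        if embedding_of(item) is not None
        else float('-inf')
    )

# Usage sketch:
#   sort_by_anchor(nodes, lambda n: n.name_embedding)
#   sort_by_anchor(edges, lambda e: e.fact_embedding)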
30 changes: 14 additions & 16 deletions core/utils/maintenance/edge_operations.py
@@ -94,27 +94,23 @@ async def dedupe_extracted_edges(
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
- for edge in existing_edges:
-     edge_map[edge.fact] = edge
for edge in extracted_edges:
-     if edge.fact in edge_map:
-         continue
-     edge_map[edge.fact] = edge
+     edge_map[edge.uuid] = edge

# Prepare context for LLM
context = {
- 'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
- 'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
+ 'extracted_edges': [{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
+ 'existing_edges': [{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges],
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
- new_edges_data = llm_response.get('new_edges', [])
- logger.info(f'Extracted new edges: {new_edges_data}')
+ unique_edge_data = llm_response.get('unique_facts', [])
+ logger.info(f'Extracted new edges: {unique_edge_data}')

# Get full edge data
edges = []
- for edge_data in new_edges_data:
-     edge = edge_map[edge_data['fact']]
+ for unique_edge in unique_edge_data:
+     edge = edge_map[unique_edge['uuid']]
edges.append(edge)

return edges
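Keying edge_map by uuid rather than fact also fixes a latent collision: two distinct edges can carry identical fact text, and a fact-keyed dict silently keeps only one of them. A tiny illustration with invented values:

edges = [
    {'uuid': 'a', 'name': 'KNOWS', 'fact': 'Alice knows Bob.'},
    {'uuid': 'b', 'name': 'FRIENDS_WITH', 'fact': 'Alice knows Bob.'},
]

by_fact = {e['fact']: e for e in edges}
by_uuid = {e['uuid']: e for e in edges}

assert len(by_fact) == 1  # the second edge clobbered the first
assert len(by_uuid) == 2  # both edges survive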
@@ -129,23 +125,25 @@ async def dedupe_edge_list(
# Create edge map
edge_map = {}
for edge in edges:
- edge_map[edge.fact] = edge
+ edge_map[edge.uuid] = edge

# Prepare context for LLM
- context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
+ context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge_list(context)
)
- unique_edges_data = llm_response.get('unique_edges', [])
+ unique_edges_data = llm_response.get('unique_facts', [])

end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

# Get full edge data
unique_edges = []
for edge_data in unique_edges_data:
- fact = edge_data['fact']
- unique_edges.append(edge_map[fact])
+ uuid = edge_data['uuid']
+ edge = edge_map[uuid]
+ edge.fact = edge_data['fact']
+ unique_edges.append(edge)

return unique_edges
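One caveat with the uuid round-trip: edge_map[uuid] raises KeyError if the model echoes back a uuid it invented. A defensive variant, sketched with stand-in dicts; the skip-and-warn policy is an assumption, not the commit's behavior:

import logging

logger = logging.getLogger(__name__)

# Stand-in data; in the real code these come from edge_map and the LLM response.
edge_map = {'a': {'uuid': 'a', 'fact': 'old wording'}}
unique_edges_data = [
    {'uuid': 'a', 'fact': 'merged wording'},
    {'uuid': 'zzz', 'fact': 'hallucinated entry'},
]

unique_edges = []
for edge_data in unique_edges_data:
    edge = edge_map.get(edge_data['uuid'])
    if edge is None:  # the model invented a uuid; skip instead of raising
        logger.warning('LLM returned unknown uuid: %s', edge_data['uuid'])
        continue
    edge['fact'] = edge_data['fact']  # adopt the model's merged wording
    unique_edges.append(edge)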
4 changes: 2 additions & 2 deletions examples/podcast/podcast_runner.py
@@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
episode_type='string',
reference_time=message.actual_timestamp,
)
- for i, message in enumerate(messages[3:7])
+ for i, message in enumerate(messages[3:14])
]

await client.add_episode_bulk(episodes)


- asyncio.run(main())
+ asyncio.run(main(True))
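Both tweaks widen the bulk test: main(use_bulk=True) was already the default, so passing True only makes the bulk path explicit, while the larger slice feeds 11 episodes instead of 4 through add_episode_bulk. A quick slice-arithmetic check:

messages = list(range(20))        # stand-in for the parsed podcast messages
assert len(messages[3:7]) == 4    # old slice: 4 episodes
assert len(messages[3:14]) == 11  # new slice: 11 episodes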
