Open
Description
Simple Mem0ai Integration:
This implementation looks very good and has a simple interface. In the meantime, I've been using the one below instead; the versatility of the OpenAI Agents hooks made it possible, and it has worked very well. (Note that I used the mem0ai REST API server as the memory backend, but any backend would work given a proper interface.)
import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Generic, TypeVar

import randomname
from agents import Agent, RunContextWrapper, Runner, TResponseInputItem, Tool
from agents.lifecycle import RunHooks
from agents.run import RunConfig, RunResult, RunResultStreaming
from agents.tool import FunctionTool
from fastapi.encoders import jsonable_encoder
from mem0 import AsyncMemoryClient
from pydantic import BaseModel, Field

from server.mem0ai.models import (
    EntityType,
    HandoffData,
    MemoryCreate,
    MemoryCreateResponse,
    MemoryData,
    MemoryListResponse,
    Message,
    Metadata,
)
# Default upper bound on agent turns used by every MemoryRunner entry point.
DEFAULT_MAX_TURNS: int = 10

# Type variable for run-context models; restricted to pydantic models.
TMemoryContext = TypeVar("TMemoryContext", bound=BaseModel)
class MemoryManager:
    """Async facade over a mem0 ``AsyncMemoryClient``.

    Translates between the project's ``Message``/``Metadata`` models and the
    client's dict-based ``add`` / ``get`` / ``get_all`` calls.
    """

    def __init__(self, client: AsyncMemoryClient):
        self.client = client
        # Strong references to fire-and-forget writes scheduled by
        # ``add_memories(background=True)``. The event loop only keeps weak
        # references to tasks, so without this set a pending write could be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def _add_memories_task(
        self,
        user_id: str,
        messages: list[Message],
        metadata: Metadata | None = None,
    ) -> list[Message]:
        """Store *messages* for *user_id* unless the batch repeats the last memory.

        Args:
            user_id: Owner of the memories.
            messages: Messages to store.
            metadata: Optional metadata attached to the write.

        Returns:
            The memories the backend reports as created, or ``[]`` when nothing
            was written (empty input or duplicate detected).
        """
        if not messages:
            # Nothing to store; skip the backend round-trip entirely.
            return []
        logging.info(f"Adding {len(messages)} memories for user {user_id}")
        # Best-effort duplicate guard against the single memory returned by
        # get_all(limit=1). A failure here must not block the write itself.
        try:
            last_memories = await self.client.get_all(user_id=user_id, limit=1)
            if last_memories.get('results'):
                last_memory_content = last_memories['results'][0]['memory']
                for message in messages:
                    if message.content == last_memory_content:
                        logging.warning(f"Duplicate memory detected for user {user_id}: '{message.content[:50]}...' - skipping")
                        return []
        except Exception as e:
            logging.debug(f"Could not check for duplicates: {e}")
        memory_create = MemoryCreate(
            user_id=user_id,
            messages=messages,
            metadata=metadata,
        )
        res = await self.client.add(**memory_create.model_dump(mode="json", exclude_none=True))
        response = MemoryCreateResponse.model_validate(res)
        return [Message(role=m.role, content=m.memory) for m in response.results]

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Release a finished background write and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Nobody awaits background writes, so surface failures here instead
            # of only as an asyncio "exception was never retrieved" warning.
            logging.error(f"Background memory write failed: {exc!r}", exc_info=exc)

    async def add_memories(
        self,
        user_id: str,
        messages: Message | list[Message],
        metadata: Metadata | None = None,
        background: bool = True,
    ) -> list[Message] | None:
        """Add a single message or a list of messages to memory.

        Args:
            user_id: Owner of the memories.
            messages: One ``Message`` or a list of them.
            metadata: Optional metadata attached to the write.
            background: When true, schedule the write as a task on the running
                loop and return immediately.

        Returns:
            ``None`` when *background* is true; otherwise the created memories.
        """
        if isinstance(messages, Message):
            messages = [messages]
        if background:
            task = asyncio.create_task(self._add_memories_task(user_id, messages, metadata))
            self._background_tasks.add(task)
            task.add_done_callback(self._on_background_task_done)
            return None
        return await self._add_memories_task(user_id, messages, metadata)

    async def get_memory(self, memory_id: str) -> Message:
        """Fetch one memory by its backend id and return it as a ``Message``."""
        logging.info(f"Getting memory for {memory_id}")
        res = await self.client.get(memory_id=memory_id)
        memory = MemoryData.model_validate(res)
        return Message(role=memory.role, content=memory.memory)

    async def get_memories(
        self,
        user_id: str,
        limit: int | None = None,
        filter_by: dict | None = None,
    ) -> list[Message]:
        """List a user's memories, optionally limited and filtered.

        Args:
            user_id: Owner of the memories.
            limit: Forwarded to ``get_all`` when truthy.
            filter_by: Passed to ``MemoryListResponse.filter_by`` when truthy.

        Returns:
            The memories as ``Message`` objects, in the order the response holds them.
        """
        logging.info(f"Getting memories for user {user_id} with limit={limit}, filter_by={filter_by}")
        params: dict = {"user_id": user_id}
        if limit:
            params["limit"] = limit
        res = await self.client.get_all(**params)
        response = MemoryListResponse.model_validate(res)
        if filter_by:
            response = response.filter_by(filter_by)
        return [Message(role=m.role, content=m.memory) for m in response.results]
class MemoryContext(BaseModel):
    """Per-run context consumed by MemoryRunner and MemoryHooks."""

    # Namespace under which memories are stored and loaded. ``default_factory``
    # draws a fresh random name for each instance; a plain default would be
    # evaluated once at import time and shared by every context created
    # without an explicit user_id, silently merging unrelated users' memories.
    user_id: str = Field(default_factory=lambda: randomname.get_name())
    # Conversation thread label (not read by the code shown in this module).
    thread_id: str = "default_thread"
    # The raw input of the current run; MemoryHooks.on_agent_start stores its str() form.
    user_input: str | list[TResponseInputItem] = ""
    # Free-form extra data for the run.
    metadata: dict | None = None
@dataclass
class MemoryAgent(Agent[TMemoryContext], Generic[TMemoryContext]):
    """Agent that carries the MemoryManager used to store and load its memories.

    Adds one optional dataclass field on top of the agents SDK ``Agent``.
    """

    # Backend read via ``agent.memory_manager`` by this module's hooks and
    # runner. Defaults to None, meaning no backend is configured.
    memory_manager: MemoryManager | None = None
class MemoryHooks(RunHooks[TMemoryContext], Generic[TMemoryContext]):
    """Run hooks that persist agent I/O, handoffs and tool calls to memory.

    Each event becomes one ``Message`` stored for ``context.context.user_id``
    through the acting agent's ``memory_manager``, tagged with ``Metadata``
    describing the entity type, direction (``io``) and name.
    """

    @staticmethod
    async def _record(
        context: RunContextWrapper[MemoryContext],
        agent: Agent,
        *,
        role: str,
        content: str,
        metadata: Metadata,
        background: bool,
    ) -> None:
        """Store one event message via *agent*'s memory manager, if it has one."""
        memory_manager = getattr(agent, "memory_manager", None)
        if memory_manager is None:
            # A plain Agent (e.g. a handoff target that is not a MemoryAgent) or a
            # MemoryAgent built without a manager: log instead of crashing the run.
            logging.warning(f"Agent {agent.name} has no memory_manager; skipping memory write")
            return
        await memory_manager.add_memories(
            user_id=context.context.user_id,
            messages=Message(role=role, content=content),
            metadata=metadata,
            background=background,
        )

    async def on_agent_start(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
    ) -> None:
        """Store the run's user input as a user message before the agent runs."""
        print(f"🔍 AGENT START: {agent.name}, Input: {str(context.context.user_input)[:100]}...")
        await self._record(
            context,
            agent,
            role="user",
            content=str(context.context.user_input),
            metadata=Metadata(entity=EntityType.AGENT, structured=False, io="input", name=agent.name),
            background=False,
        )

    async def on_agent_end(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        output: str | BaseModel,
    ) -> None:
        """Store the agent's final output; pydantic outputs are serialized to JSON."""
        print(f"🔍 AGENT END: {agent.name}, Output: {str(output)[:100]}...")
        structured = isinstance(output, BaseModel)
        await self._record(
            context,
            agent,
            role="assistant",
            content=output.model_dump_json() if structured else str(output),
            metadata=Metadata(entity=EntityType.AGENT, structured=structured, io="output", name=agent.name),
            background=False,
        )

    async def on_handoff(
        self,
        context: RunContextWrapper[MemoryContext],
        from_agent: MemoryAgent,
        to_agent: MemoryAgent,
    ) -> None:
        """Store a handoff record (as a background write) through the source agent."""
        print(f"🔄 HANDOFF: {from_agent.name} → {to_agent.name}")
        handoff_data = HandoffData(from_agent=from_agent.name, to_agent=to_agent.name)
        await self._record(
            context,
            from_agent,
            role="system",
            content=handoff_data.model_dump_json(),
            metadata=Metadata(
                entity=EntityType.HANDOFF,
                structured=True,
                io="output",
                name=f"{from_agent.name}->{to_agent.name}",
            ),
            background=True,
        )

    async def on_tool_start(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        tool: Tool,
    ) -> None:
        """Store the invoked tool's name, description and parameter schema."""
        print(f"🔧 TOOL START: {agent.name} calling '{tool.name}'")
        tool_input_schema = {
            "tool_name": tool.name,
            # Only FunctionTool guarantees these attributes; other tool kinds
            # may lack them, so fall back to None instead of raising.
            "description": getattr(tool, "description", None),
            "params_schema": getattr(tool, "params_json_schema", None),
        }
        await self._record(
            context,
            agent,
            role="system",
            content=json.dumps(tool_input_schema),
            metadata=Metadata(entity=EntityType.TOOL, structured=True, io="input", name=tool.name),
            background=False,
        )

    async def on_tool_end(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        tool: Tool,
        result: str | BaseModel,
    ) -> None:
        """Store the tool's result; pydantic results are serialized to JSON."""
        print(f"🔧 TOOL END: {agent.name} finished '{tool.name}' → {result}")
        logging.info(f"Tool {tool.name} finished by {agent.name} with result {result}")
        structured = isinstance(result, BaseModel)
        tool_result_data = {
            "tool_name": tool.name,
            "result": result.model_dump_json() if structured else str(result),
            "success": True,
        }
        await self._record(
            context,
            agent,
            role="system",
            content=json.dumps(tool_result_data),
            metadata=Metadata(entity=EntityType.TOOL, structured=structured, io="output", name=tool.name),
            background=False,
        )
class MemoryRunner(Runner):
    """``Runner`` that prepends a user's stored memories to each run's input.

    Memories are loaded from ``starting_agent.memory_manager`` for
    ``context.user_id`` and placed before the new input items.
    """

    @staticmethod
    async def _inject_previous_memories(
        user_id: str,
        input: str | list[TResponseInputItem],
        memory_manager: MemoryManager | None,
    ) -> list[TResponseInputItem]:
        """Return the user's memories followed by *input* as a list of input items.

        A string *input* becomes one user message. A list *input* is spliced in
        unchanged rather than nested inside a single message's ``content``.
        Without a memory manager, only the new input items are returned.
        """
        if isinstance(input, str):
            new_items: list[TResponseInputItem] = [{"role": "user", "content": input}]
        else:
            new_items = list(input)
        if memory_manager is None:
            return new_items
        memories = [
            m.model_dump(include={"role", "content"})
            for m in await memory_manager.get_memories(user_id)
        ]
        return memories + new_items

    @staticmethod
    def _blocking_loop(method: str) -> asyncio.AbstractEventLoop:
        """Return an event loop that a synchronous entry point may block on.

        Raises:
            RuntimeError: if called from inside a running event loop, where
                blocking with ``run_until_complete`` is impossible.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            raise RuntimeError(
                f"MemoryRunner.{method}() cannot block inside a running event loop; "
                "use `await MemoryRunner.run(...)` or "
                "`await MemoryRunner.run_streamed_async(...)` instead."
            )
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            # No current loop (newer Pythons no longer create one implicitly):
            # create one and make it current so later sync calls reuse it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    @classmethod
    async def run(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        """Run the agent with the user's memories prepended to *input*."""
        user_input = await cls._inject_previous_memories(
            user_id=context.user_id,
            input=input,
            memory_manager=starting_agent.memory_manager,
        )
        return await super().run(
            starting_agent,
            input=user_input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
        )

    @classmethod
    def run_sync(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        """Blocking wrapper around :meth:`run`.

        Drives ``cls.run`` directly instead of injecting here and then calling
        ``super().run_sync``. In SDK versions where the base ``run_sync`` itself
        dispatches to ``cls.run``, that path would inject the memories twice.
        """
        loop = cls._blocking_loop("run_sync")
        return loop.run_until_complete(
            cls.run(
                starting_agent,
                input,
                context=context,
                max_turns=max_turns,
                hooks=hooks,
                run_config=run_config,
                previous_response_id=previous_response_id,
            )
        )

    @classmethod
    def run_streamed(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResultStreaming:
        """Start a streamed run after synchronously loading memories.

        Memory loading is async, so it cannot happen from within a running event
        loop. In that case a descriptive RuntimeError is raised; use
        :meth:`run_streamed_async` from async code.
        """
        loop = cls._blocking_loop("run_streamed")
        user_input = loop.run_until_complete(
            cls._inject_previous_memories(
                user_id=context.user_id,
                input=input,
                memory_manager=starting_agent.memory_manager,
            )
        )
        return super().run_streamed(
            starting_agent,
            input=user_input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
        )

    @classmethod
    async def run_streamed_async(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResultStreaming:
        """Async-friendly streaming entry point: await memory loading, then start streaming."""
        user_input = await cls._inject_previous_memories(
            user_id=context.user_id,
            input=input,
            memory_manager=starting_agent.memory_manager,
        )
        return super().run_streamed(
            starting_agent,
            input=user_input,
            context=context,
            max_turns=max_turns,
            hooks=hooks,
            run_config=run_config,
            previous_response_id=previous_response_id,
        )
Simple Usage
"""Minimal multi-turn demo of MemoryAgent/MemoryRunner against a mem0 server."""
import asyncio
import os

import randomname
from agents import trace
from mem0 import AsyncMemoryClient

from memory import MemoryAgent, MemoryContext, MemoryHooks, MemoryManager, MemoryRunner

INSTRUCTIONS = """
You are an intelligent assistant with persistent memory capabilities.
You have access to previous conversation context that gets automatically loaded.
Be conversational, helpful, and remember details from our ongoing conversation.
You don't need to explicitly mention your memory capabilities unless asked.
"""

# Each turn is a separate run; continuity comes only from the memory backend.
CONVERSATION = [
    "What city is the Golden Gate Bridge in?",
    "What state is it in?",
    "The capital of my town is called Rigobertera. Where do you come from?",
    "Which is my capital?",
]


async def main() -> None:
    """Run the demo conversation for a fresh random user."""
    # Credentials and endpoint come from the environment rather than source code.
    client = AsyncMemoryClient(
        api_key=os.environ.get("MEM0_API_KEY", "local-dev-key"),
        host=os.environ.get("MEM0_HOST", "http://localhost:8888"),
    )
    memory_agent = MemoryAgent(
        name="Memory Assistant",
        model="gpt-4.1",
        instructions=INSTRUCTIONS,
        memory_manager=MemoryManager(client),
    )
    hooks = MemoryHooks()
    user_id = randomname.get_name()

    # group_id links the traces of this one conversation together.
    with trace(workflow_name="Conversation", group_id=user_id):
        for user_input in CONVERSATION:
            print(user_input)
            result = await MemoryRunner.run(
                starting_agent=memory_agent,
                input=user_input,
                hooks=hooks,
                context=MemoryContext(user_id=user_id, user_input=user_input),
            )
            print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
``