
Mem0ai integration as an example for creating memory agents #832

Open
@alanredmaiar

Description


Simple Mem0ai Integration:

Guys, this implementation looks very good and the interface is simple. In the meantime I've been using the approach below instead, because the versatility of the openai-agents hooks makes it possible, and it has worked very nicely. Note that I used the mem0ai REST API server as the memory backend, but it could be any backend if we define a proper interface (see the sketch after the listing):

import asyncio
import logging
import json
from typing import TypeVar, Generic
from dataclasses import dataclass

from agents import Agent, RunContextWrapper, Runner, TResponseInputItem
from agents.lifecycle import RunHooks
from agents.run import RunResult, RunConfig, RunResultStreaming
from mem0 import AsyncMemoryClient
from pydantic import BaseModel, Field
import randomname
from agents.tool import FunctionTool

from server.mem0ai.models import Message, MemoryCreate, MemoryData, MemoryCreateResponse, MemoryListResponse, Metadata, EntityType, HandoffData


DEFAULT_MAX_TURNS = 10


TMemoryContext = TypeVar('TMemoryContext', bound=BaseModel)


class MemoryManager:
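    """Thin async wrapper around AsyncMemoryClient for adding and retrieving memories."""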
    def __init__(self, client: AsyncMemoryClient):
        self.client = client
        # Keep strong references to fire-and-forget tasks so they aren't garbage-collected mid-run
        self._background_tasks: set[asyncio.Task] = set()

    async def _add_memories_task(self, user_id: str, messages: list[Message], metadata: Metadata | None = None) -> list[Message]:
        """Add memories to storage with duplicate check."""
        logging.info(f"Adding {len(messages)} memories for user {user_id}")
        
        # Check for duplicates by getting the last memory
        try:
            last_memories = await self.client.get_all(user_id=user_id, limit=1)
            if last_memories.get('results'):
                last_memory_content = last_memories['results'][0]['memory']
                # Check if the last message content matches any of the new messages
                for message in messages:
                    if message.content == last_memory_content:
                        logging.warning(f"Duplicate memory detected for user {user_id}: '{message.content[:50]}...' - skipping")
                        return []
        except Exception as e:
            logging.debug(f"Could not check for duplicates: {e}")

        memory_create = MemoryCreate(
            user_id=user_id,
            messages=messages,
            metadata=metadata
        )
        res = await self.client.add(**memory_create.model_dump(mode="json", exclude_none=True))
        response = MemoryCreateResponse.model_validate(res)
        return [Message(role=m.role, content=m.memory) for m in response.results]

    async def add_memories(
        self, 
        user_id: str, 
        messages: Message | list[Message], 
        metadata: Metadata | None = None,
        background: bool = True
    ) -> list[Message] | None:
        """Add single message or list of messages to memory."""
        # Convert single message to list
        if isinstance(messages, Message):
            messages = [messages]
        
        if background:
            task = asyncio.create_task(self._add_memories_task(user_id, messages, metadata))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            return None
        return await self._add_memories_task(user_id, messages, metadata)

    async def get_memory(self, memory_id: str) -> Message:
        """Get single memory by ID."""
        logging.info(f"Getting memory for {memory_id}")
        res = await self.client.get(memory_id=memory_id)
        memory = MemoryData.model_validate(res)
        return Message(role=memory.role, content=memory.memory)
    
    async def get_memories(
        self, 
        user_id: str, 
        limit: int | None = None, 
        filter_by: dict | None = None
    ) -> list[Message]:
        """Get memories with optional limit and filtering."""
        logging.info(f"Getting memories for user {user_id} with limit={limit}, filter_by={filter_by}")
        
        # Get memories with optional limit
        params = {"user_id": user_id}
        if limit is not None:
            params["limit"] = limit
            
        res = await self.client.get_all(**params)
        response = MemoryListResponse.model_validate(res)
        
        # Apply filtering if specified
        if filter_by:
            response = response.filter_by(filter_by)
        
        return [Message(role=m.role, content=m.memory) for m in response.results]


class MemoryContext(BaseModel):
    # default_factory so each context gets a fresh name; a plain default is evaluated once at import time
    user_id: str = Field(default_factory=randomname.get_name)
    thread_id: str = "default_thread"
    user_input: str | list[TResponseInputItem] = ""
    metadata: dict | None = None


@dataclass
class MemoryAgent(Agent[TMemoryContext], Generic[TMemoryContext]):
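    """Agent subclass that carries a MemoryManager so the lifecycle hooks can persist memories."""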
    memory_manager: MemoryManager | None = None


class MemoryHooks(RunHooks[TMemoryContext], Generic[TMemoryContext]):
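    """Lifecycle hooks that persist agent inputs/outputs, handoffs, and tool calls as memories."""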

    async def on_agent_start(
        self, 
        context: RunContextWrapper[MemoryContext], 
        agent: MemoryAgent
    ) -> None:
        print(f"🔍 AGENT START: {agent.name}, Input: {str(context.context.user_input)[:100]}...")
        message = Message(role="user", content=str(context.context.user_input))
        metadata = Metadata(
            entity=EntityType.AGENT, 
            structured=False, 
            io="input", 
            name=agent.name
        )
        await agent.memory_manager.add_memories(
            user_id=context.context.user_id, 
            messages=message, 
            metadata=metadata,
            background=False
        )

    async def on_agent_end(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        output: str | BaseModel
    ) -> None:
        print(f"🔍 AGENT END: {agent.name}, Output: {str(output)[:100]}...")
        content = output.model_dump_json() if isinstance(output, BaseModel) else str(output)
        message = Message(role="assistant", content=content)
        metadata = Metadata(
            entity=EntityType.AGENT, 
            structured=isinstance(output, BaseModel), 
            io="output", 
            name=agent.name
        )
        await agent.memory_manager.add_memories(
            user_id=context.context.user_id, 
            messages=message, 
            metadata=metadata,
            background=False
        )

    async def on_handoff(
        self,
        context: RunContextWrapper[MemoryContext],
        from_agent: MemoryAgent,
        to_agent: MemoryAgent,
    ) -> None:
        print(f"🔄 HANDOFF: {from_agent.name}{to_agent.name}")
        handoff_data = HandoffData(from_agent=from_agent.name, to_agent=to_agent.name)
        handoff_message = Message(role="system", content=handoff_data.model_dump_json())
        metadata = Metadata(
            entity=EntityType.HANDOFF, 
            structured=True, 
            io="output", 
            name=f"{from_agent.name}->{to_agent.name}"
        )
        await from_agent.memory_manager.add_memories(
            user_id=context.context.user_id, 
            messages=handoff_message, 
            metadata=metadata,
            background=True
        )

    async def on_tool_start(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        tool: FunctionTool,
    ) -> None:
        print(f"🔧 TOOL START: {agent.name} calling '{tool.name}'")
        tool_input_schema = {
            "tool_name": tool.name,
            "description": tool.description,
            "params_schema": tool.params_json_schema
        }
        
        tool_start_message = Message(role="system", content=json.dumps(tool_input_schema))
        metadata = Metadata(
            entity=EntityType.TOOL, 
            structured=True, 
            io="input", 
            name=tool.name
        )
        await agent.memory_manager.add_memories(
            user_id=context.context.user_id, 
            messages=tool_start_message, 
            metadata=metadata,
            background=False
        )

    async def on_tool_end(
        self,
        context: RunContextWrapper[MemoryContext],
        agent: MemoryAgent,
        tool: FunctionTool,
        result: str | BaseModel,
    ) -> None:
        print(f"🔧 TOOL END: {agent.name} finished '{tool.name}' → {result}")
        logging.info(f"Tool {tool.name} finished by {agent.name} with result {result}")
        tool_result_data = {
            "tool_name": tool.name,
            "result": result.model_dump_json() if isinstance(result, BaseModel) else str(result),
            "success": True
        }
        tool_end_message = Message(role="system", content=json.dumps(tool_result_data))
        metadata = Metadata(
            entity=EntityType.TOOL, 
            structured=isinstance(result, BaseModel), 
            io="output", 
            name=tool.name
        )
        await agent.memory_manager.add_memories(
            user_id=context.context.user_id, 
            messages=tool_end_message, 
            metadata=metadata,
            background=False
        )


class MemoryRunner(Runner):
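    """Runner that prepends previously stored memories to the input before delegating to the base Runner."""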

    @staticmethod
    async def _inject_previous_memories(
        user_id: str,
        input: str | list[TResponseInputItem],
        memory_manager: MemoryManager,
    ) -> list[TResponseInputItem]:
        memories = [m.model_dump(include={"role", "content"}) for m in await memory_manager.get_memories(user_id)]
        if isinstance(input, str):
            return memories + [{"role": "user", "content": input}]
        return memories + input

    @classmethod
    async def run(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        user_input = await cls._inject_previous_memories(user_id=context.user_id, input=input, memory_manager=starting_agent.memory_manager)
        return await super().run(starting_agent, input=user_input, context=context, max_turns=max_turns, hooks=hooks, run_config=run_config, previous_response_id=previous_response_id)

    @classmethod
    def run_sync(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResult:
        user_input = asyncio.run(cls._inject_previous_memories(user_id=context.user_id, input=input, memory_manager=starting_agent.memory_manager))
        return super().run_sync(starting_agent, input=user_input, context=context, max_turns=max_turns, hooks=hooks, run_config=run_config, previous_response_id=previous_response_id)
    
    @classmethod
    def run_streamed(
        cls,
        starting_agent: MemoryAgent[MemoryContext],
        input: str | list[TResponseInputItem],
        *,
        context: MemoryContext,
        max_turns: int = DEFAULT_MAX_TURNS,
        hooks: MemoryHooks[MemoryContext] | None = None,
        run_config: RunConfig | None = None,
        previous_response_id: str | None = None,
    ) -> RunResultStreaming:
        user_input = asyncio.run(cls._inject_previous_memories(user_id=context.user_id, input=input, memory_manager=starting_agent.memory_manager))
        return super().run_streamed(starting_agent, input=user_input, context=context, max_turns=max_turns, hooks=hooks, run_config=run_config, previous_response_id=previous_response_id)
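
On the "could be anyone" point: here is a minimal sketch of what a backend-agnostic interface might look like, assuming we only need add/list semantics. MemoryBackend and Mem0Backend are hypothetical names for illustration, not part of mem0 or the agents SDK:

from typing import Protocol

from mem0 import AsyncMemoryClient

from server.mem0ai.models import Message, MemoryCreateResponse, MemoryListResponse, Metadata


class MemoryBackend(Protocol):
    """Hypothetical backend interface; any store implementing these coroutines could back MemoryManager."""

    async def add(self, user_id: str, messages: list[Message], metadata: Metadata | None = None) -> list[Message]: ...

    async def get_all(self, user_id: str, limit: int | None = None) -> list[Message]: ...


class Mem0Backend:
    """Adapter that puts AsyncMemoryClient behind the hypothetical MemoryBackend protocol."""

    def __init__(self, client: AsyncMemoryClient):
        self.client = client

    async def add(self, user_id: str, messages: list[Message], metadata: Metadata | None = None) -> list[Message]:
        res = await self.client.add(
            messages=[m.model_dump(mode="json") for m in messages],
            user_id=user_id,
            metadata=metadata.model_dump(mode="json", exclude_none=True) if metadata else None,
        )
        return [Message(role=m.role, content=m.memory) for m in MemoryCreateResponse.model_validate(res).results]

    async def get_all(self, user_id: str, limit: int | None = None) -> list[Message]:
        params: dict = {"user_id": user_id}
        if limit is not None:
            params["limit"] = limit
        res = await self.client.get_all(**params)
        return [Message(role=m.role, content=m.memory) for m in MemoryListResponse.model_validate(res).results]

MemoryManager would then take a MemoryBackend instead of AsyncMemoryClient directly, so swapping mem0 for Redis, Postgres, or anything else is just another adapter.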

Simple Usage

from agents import trace
from mem0 import AsyncMemoryClient
import randomname

from memory import MemoryAgent, MemoryManager, MemoryContext, MemoryHooks, MemoryRunner


# Placeholder credentials; point these at your own mem0 server
client = AsyncMemoryClient(api_key="your-api-key", host="http://localhost:8888")
memory_manager = MemoryManager(client)

memory_agent = MemoryAgent(
    name="Memory Assistant",
    model="gpt-4.1",
    instructions="""
        You are an intelligent assistant with persistent memory capabilities.
        You have access to previous conversation context that gets automatically loaded.
        Be conversational, helpful, and remember details from our ongoing conversation.
        You don't need to explicitly mention your memory capabilities unless asked.
    """,
    memory_manager=memory_manager,
)

user_id = randomname.get_name()


# Top-level await assumes an async context (e.g. a notebook); see the script wrapper after this snippet
with trace(workflow_name="Conversation", group_id="memory-demo"):
    # First turn
    user_input = "What city is the Golden Gate Bridge in?"
    print(user_input)
    result = await MemoryRunner.run(
        starting_agent=memory_agent,
        input=user_input,
        hooks=MemoryHooks(),
        context=MemoryContext(user_id=user_id, user_input=user_input)
    )
    print(result.final_output)
    user_input = "What state is it in?"
    print(user_input)
    result = await MemoryRunner.run(
        starting_agent=memory_agent,
        input=user_input,
        hooks=MemoryHooks(),
        context=MemoryContext(user_id=user_id, user_input=user_input)
    )
    print(result.final_output)
    user_input = "The capital of my town is called Rigobertera. Where do you come from?"
    print(user_input)
    result = await MemoryRunner.run(
        starting_agent=memory_agent,
        input=user_input,
        hooks=MemoryHooks(),
        context=MemoryContext(user_id=user_id, user_input=user_input)
    )
    print(result.final_output)
    user_input = "Which is my capital?"
    print(user_input)
    result = await MemoryRunner.run(
        starting_agent=memory_agent,
        input=user_input,
        hooks=MemoryHooks(),
        context=MemoryContext(user_id=user_id, user_input=user_input)
    )
    print(result.final_output)
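
If you run this as a plain script rather than in a notebook, wrap the turns in an async entry point (same names as above; main is just an illustrative wrapper):

import asyncio

async def main() -> None:
    user_input = "What city is the Golden Gate Bridge in?"
    result = await MemoryRunner.run(
        starting_agent=memory_agent,
        input=user_input,
        hooks=MemoryHooks(),
        context=MemoryContext(user_id=user_id, user_input=user_input),
    )
    print(result.final_output)

if __name__ == "__main__":
    asyncio.run(main())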

