"""MCP server exposing memory tools (add / search / list / delete) over SSE.

The module wires a ``FastMCP`` server to a FastAPI router. Each SSE connection
carries a ``user_id`` and a ``client_name`` in its URL path; they are stored in
context variables so the tool functions can read them while the connection is
being served.
"""

import contextvars
import datetime
import json
import logging
import os
import uuid

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from qdrant_client import models as qdrant_models

from app.utils.memory import get_memory_client
from app.database import SessionLocal
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
from app.utils.db import get_user_and_app
from app.utils.permissions import check_memory_access_permissions

# Load environment variables
load_dotenv()

# Initialize MCP and memory client
mcp = FastMCP("mem0-mcp-server")

# Check if OpenAI API key is set
if not os.getenv("OPENAI_API_KEY"):
    raise Exception("OPENAI_API_KEY is not set in .env file")
memory_client = get_memory_client()

# Context variables for user_id and client_name
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")

# Create a router for MCP endpoints
mcp_router = APIRouter(prefix="/mcp")

# Initialize SSE transport
sse = SseServerTransport("/mcp/messages/")


@mcp.tool(description=(
    "Add a new memory. This method is called everytime the user informs anything "
    "about themselves, their preferences, or anything that has any relevant "
    "information which can be useful in the future conversation. This can also "
    "be called when the user asks you to remember something."
))
async def add_memories(text: str) -> str:
    """Store ``text`` via the memory client and mirror the outcome in the database.

    Args:
        text: The content to remember.

    Returns:
        The memory client's response serialised as JSON, or a string starting
        with ``"Error"`` when the request context is incomplete, the app is
        paused, or any step raises.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Check if app is active
            if not app.is_active:
                return (
                    f"Error: App {app.name} is currently paused on OpenMemory. "
                    "Cannot create new memories."
                )

            response = memory_client.add(
                text,
                user_id=uid,
                metadata={
                    "source_app": "openmemory",
                    "mcp_client": client_name,
                },
            )

            # Process the response and update database
            if isinstance(response, dict) and 'results' in response:
                for result in response['results']:
                    memory_id = uuid.UUID(result['id'])
                    memory = db.query(Memory).filter(Memory.id == memory_id).first()

                    if result['event'] == 'ADD':
                        if not memory:
                            memory = Memory(
                                id=memory_id,
                                user_id=user.id,
                                app_id=app.id,
                                content=result['memory'],
                                state=MemoryState.active,
                            )
                            db.add(memory)
                        else:
                            memory.state = MemoryState.active
                            memory.content = result['memory']

                        # Create history entry
                        # NOTE(review): `memory` is always set at this point, so
                        # old_state is always MemoryState.deleted, even for a
                        # brand-new memory. Kept as-is; confirm whether the
                        # old_state column accepts None before changing it.
                        history = MemoryStatusHistory(
                            memory_id=memory_id,
                            changed_by=user.id,
                            old_state=MemoryState.deleted if memory else None,
                            new_state=MemoryState.active,
                        )
                        db.add(history)

                    elif result['event'] == 'DELETE':
                        if memory:
                            memory.state = MemoryState.deleted
                            memory.deleted_at = datetime.datetime.now(datetime.UTC)
                            # Create history entry
                            history = MemoryStatusHistory(
                                memory_id=memory_id,
                                changed_by=user.id,
                                old_state=MemoryState.active,
                                new_state=MemoryState.deleted,
                            )
                            db.add(history)

                db.commit()

            # The tool is declared to return `str`; serialise instead of
            # handing the raw response object back to the MCP layer.
            return json.dumps(response, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error adding to memory: {e}"


@mcp.tool(description=(
    "Search through stored memories. "
    "This method is called EVERYTIME the user asks anything."
))
async def search_memory(query: str) -> str:
    """Run a vector search restricted to the memories this app may access.

    Args:
        query: Free-text search query.

    Returns:
        A JSON array (at most 10 entries) of objects with ``id``, ``memory``,
        ``hash``, ``created_at``, ``updated_at`` and ``score``; or a string
        starting with ``"Error"`` on failure. An access-log row is written for
        every returned memory.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get accessible memory IDs based on ACL
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = [
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # Nothing is accessible: return an empty result rather than
            # searching without an id restriction, which would expose every
            # memory stored for this user.
            if not accessible_memory_ids:
                return json.dumps([], indent=2)

            # Qdrant expects string ids, so convert the UUIDs.
            accessible_memory_ids_str = [
                str(memory_id) for memory_id in accessible_memory_ids
            ]
            filters = qdrant_models.Filter(must=[
                qdrant_models.FieldCondition(
                    key="user_id",
                    match=qdrant_models.MatchValue(value=uid),
                ),
                qdrant_models.HasIdCondition(has_id=accessible_memory_ids_str),
            ])

            embeddings = memory_client.embedding_model.embed(query, "search")

            hits = memory_client.vector_store.client.query_points(
                collection_name=memory_client.vector_store.collection_name,
                query=embeddings,
                query_filter=filters,
                limit=10,
            )

            # Process search results
            memories = [
                {
                    "id": point.id,
                    "memory": point.payload["data"],
                    "hash": point.payload.get("hash"),
                    "created_at": point.payload.get("created_at"),
                    "updated_at": point.payload.get("updated_at"),
                    "score": point.score,
                }
                for point in hits.points
            ]

            # Log memory access for each memory found
            for memory in memories:
                access_log = MemoryAccessLog(
                    memory_id=uuid.UUID(memory['id']),
                    app_id=app.id,
                    access_type="search",
                    metadata_={
                        "query": query,
                        "score": memory.get('score'),
                        "hash": memory.get('hash'),
                    },
                )
                db.add(access_log)
            db.commit()

            return json.dumps(memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error searching memory: {e}"


@mcp.tool(description="List all memories in the user's memory")
async def list_memories() -> str:
    """Return every memory of the current user that this app may access.

    Returns:
        A JSON array of the memory entries reported by the memory client,
        filtered to the accessible ones; or a string starting with ``"Error"``
        on failure. An access-log row is written for every returned memory.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get all memories
            memories = memory_client.get_all(user_id=uid)

            # Filter memories based on permissions. A set gives O(1)
            # membership tests in the loop below.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = {
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            }

            # The client may answer with {"results": [...]} or a bare list.
            if isinstance(memories, dict):
                entries = memories.get('results', [])
            else:
                entries = memories

            filtered_memories = []
            for memory_data in entries:
                if 'id' not in memory_data:
                    continue
                memory_id = uuid.UUID(memory_data['id'])
                if memory_id not in accessible_memory_ids:
                    continue

                # Create access log entry
                access_log = MemoryAccessLog(
                    memory_id=memory_id,
                    app_id=app.id,
                    access_type="list",
                    metadata_={
                        "hash": memory_data.get('hash'),
                    },
                )
                db.add(access_log)
                filtered_memories.append(memory_data)
            db.commit()

            return json.dumps(filtered_memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error getting memories: {e}"


@mcp.tool(description="Delete all memories in the user's memory")
async def delete_all_memories() -> str:
    """Delete every memory of the current user that this app may access.

    The memories are first removed through the memory client, then marked as
    deleted in the database together with a status-history row and an
    access-log row per memory.

    Returns:
        ``"Successfully deleted all memories"`` on success, or a string
        starting with ``"Error"`` on failure.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memories = [
                memory
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # delete the accessible memories only
            for memory in accessible_memories:
                memory_client.delete(memory.id)

            # Update each memory's state and create history entries. The rows
            # loaded above are reused instead of being queried again one by one.
            now = datetime.datetime.now(datetime.UTC)
            for memory in accessible_memories:
                # Update memory state
                memory.state = MemoryState.deleted
                memory.deleted_at = now

                # Create history entry
                history = MemoryStatusHistory(
                    memory_id=memory.id,
                    changed_by=user.id,
                    old_state=MemoryState.active,
                    new_state=MemoryState.deleted,
                )
                db.add(history)

                # Create access log entry
                access_log = MemoryAccessLog(
                    memory_id=memory.id,
                    app_id=app.id,
                    access_type="delete_all",
                    metadata_={"operation": "bulk_delete"},
                )
                db.add(access_log)

            db.commit()
            return "Successfully deleted all memories"
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error deleting memories: {e}"


@mcp_router.get("/{client_name}/sse/{user_id}")
async def handle_sse(request: Request):
    """Handle SSE connections for a specific user and client"""
    # Extract user_id and client_name from path parameters
    uid = request.path_params.get("user_id")
    user_token = user_id_var.set(uid or "")
    client_name = request.path_params.get("client_name")
    client_token = client_name_var.set(client_name or "")

    try:
        # Handle SSE connection. `request._send` is a private attribute of the
        # request object; it is the ASGI send callable the transport needs.
        async with sse.connect_sse(
            request.scope,
            request.receive,
            request._send,
        ) as (read_stream, write_stream):
            await mcp._mcp_server.run(
                read_stream,
                write_stream,
                mcp._mcp_server.create_initialization_options(),
            )
    finally:
        # Clean up context variables
        user_id_var.reset(user_token)
        client_name_var.reset(client_token)


async def _forward_post_message(request: Request):
    """Pass a POSTed client message to the SSE transport and acknowledge it.

    The request body is read once and replayed through a minimal ASGI
    ``receive`` callable; whatever the transport tries to send back is
    discarded and a fixed ``{"status": "ok"}`` payload is returned instead.
    """
    body = await request.body()

    # Create a simple receive function that returns the body
    async def receive():
        return {"type": "http.request", "body": body, "more_body": False}

    # Create a simple send function that does nothing
    async def send(message):
        return None

    await sse.handle_post_message(request.scope, receive, send)

    # Return a success response
    return {"status": "ok"}


@mcp_router.post("/messages/")
async def handle_get_message(request: Request):
    """Handle POST messages sent to the shared ``/messages/`` endpoint."""
    return await _forward_post_message(request)


@mcp_router.post("/{client_name}/sse/{user_id}/messages/")
async def handle_post_message(request: Request):
    """Handle POST messages for SSE"""
    return await _forward_post_message(request)


def setup_mcp_server(app: FastAPI):
    """Setup MCP server with the FastAPI application"""
    mcp._mcp_server.name = "mem0-mcp-server"

    # Include MCP router in the FastAPI app
    app.include_router(mcp_router)